Lay the groundwork to support history in agent2 (#36483)

Antonio Scandurra created

This pull request introduces title generation and history replaying. We
still need to wire up the rest of the history but this gets us very
close. I extracted a lot of this code from `agent2-history` because that
branch was starting to get long-lived and there were lots of changes
since we started.

Release Notes:

- N/A

Change summary

Cargo.lock                                         |   3 
crates/acp_thread/src/acp_thread.rs                |  39 
crates/acp_thread/src/connection.rs                |  16 
crates/acp_thread/src/diff.rs                      |  13 
crates/acp_thread/src/mention.rs                   |   3 
crates/agent2/Cargo.toml                           |   4 
crates/agent2/src/agent.rs                         | 104 +-
crates/agent2/src/tests/mod.rs                     |  94 +
crates/agent2/src/thread.rs                        | 608 ++++++++++++---
crates/agent2/src/tools/context_server_registry.rs |  10 
crates/agent2/src/tools/edit_file_tool.rs          | 137 ++-
crates/agent2/src/tools/terminal_tool.rs           |   4 
crates/agent2/src/tools/web_search_tool.rs         |  67 +
crates/agent_servers/Cargo.toml                    |   1 
crates/agent_servers/src/acp/v0.rs                 |   4 
crates/agent_servers/src/acp/v1.rs                 |   7 
crates/agent_servers/src/claude.rs                 |  12 
crates/agent_ui/src/acp/thread_view.rs             |  31 
crates/agent_ui/src/agent_diff.rs                  |  41 
19 files changed, 884 insertions(+), 314 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -191,6 +191,7 @@ version = "0.1.0"
 dependencies = [
  "acp_thread",
  "action_log",
+ "agent",
  "agent-client-protocol",
  "agent_servers",
  "agent_settings",
@@ -208,6 +209,7 @@ dependencies = [
  "env_logger 0.11.8",
  "fs",
  "futures 0.3.31",
+ "git",
  "gpui",
  "gpui_tokio",
  "handlebars 4.5.0",
@@ -256,6 +258,7 @@ name = "agent_servers"
 version = "0.1.0"
 dependencies = [
  "acp_thread",
+ "action_log",
  "agent-client-protocol",
  "agent_settings",
  "agentic-coding-protocol",

crates/acp_thread/src/acp_thread.rs 🔗

@@ -537,9 +537,15 @@ impl ToolCallContent {
             acp::ToolCallContent::Content { content } => {
                 Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
             }
-            acp::ToolCallContent::Diff { diff } => {
-                Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
-            }
+            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
+                Diff::finalized(
+                    diff.path,
+                    diff.old_text,
+                    diff.new_text,
+                    language_registry,
+                    cx,
+                )
+            })),
         }
     }
 
@@ -682,6 +688,7 @@ pub struct AcpThread {
 #[derive(Debug)]
 pub enum AcpThreadEvent {
     NewEntry,
+    TitleUpdated,
     EntryUpdated(usize),
     EntriesRemoved(Range<usize>),
     ToolAuthorizationRequired,
@@ -728,11 +735,9 @@ impl AcpThread {
         title: impl Into<SharedString>,
         connection: Rc<dyn AgentConnection>,
         project: Entity<Project>,
+        action_log: Entity<ActionLog>,
         session_id: acp::SessionId,
-        cx: &mut Context<Self>,
     ) -> Self {
-        let action_log = cx.new(|_| ActionLog::new(project.clone()));
-
         Self {
             action_log,
             shared_buffers: Default::default(),
@@ -926,6 +931,12 @@ impl AcpThread {
         cx.emit(AcpThreadEvent::NewEntry);
     }
 
+    pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
+        self.title = title;
+        cx.emit(AcpThreadEvent::TitleUpdated);
+        Ok(())
+    }
+
     pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
         cx.emit(AcpThreadEvent::Retry(status));
     }
@@ -1657,7 +1668,7 @@ mod tests {
     use super::*;
     use anyhow::anyhow;
     use futures::{channel::mpsc, future::LocalBoxFuture, select};
-    use gpui::{AsyncApp, TestAppContext, WeakEntity};
+    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
     use indoc::indoc;
     use project::{FakeFs, Fs};
     use rand::Rng as _;
@@ -2327,7 +2338,7 @@ mod tests {
             self: Rc<Self>,
             project: Entity<Project>,
             _cwd: &Path,
-            cx: &mut gpui::App,
+            cx: &mut App,
         ) -> Task<gpui::Result<Entity<AcpThread>>> {
             let session_id = acp::SessionId(
                 rand::thread_rng()
@@ -2337,8 +2348,16 @@ mod tests {
                     .collect::<String>()
                     .into(),
             );
-            let thread =
-                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
+            let action_log = cx.new(|_| ActionLog::new(project.clone()));
+            let thread = cx.new(|_cx| {
+                AcpThread::new(
+                    "Test",
+                    self.clone(),
+                    project,
+                    action_log,
+                    session_id.clone(),
+                )
+            });
             self.sessions.lock().insert(session_id, thread.downgrade());
             Task::ready(Ok(thread))
         }

crates/acp_thread/src/connection.rs 🔗

@@ -5,11 +5,12 @@ use collections::IndexMap;
 use gpui::{Entity, SharedString, Task};
 use language_model::LanguageModelProviderId;
 use project::Project;
+use serde::{Deserialize, Serialize};
 use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 use ui::{App, IconName};
 use uuid::Uuid;
 
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct UserMessageId(Arc<str>);
 
 impl UserMessageId {
@@ -208,6 +209,7 @@ impl AgentModelList {
 mod test_support {
     use std::sync::Arc;
 
+    use action_log::ActionLog;
     use collections::HashMap;
     use futures::{channel::oneshot, future::try_join_all};
     use gpui::{AppContext as _, WeakEntity};
@@ -295,8 +297,16 @@ mod test_support {
             cx: &mut gpui::App,
         ) -> Task<gpui::Result<Entity<AcpThread>>> {
             let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
-            let thread =
-                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
+            let action_log = cx.new(|_| ActionLog::new(project.clone()));
+            let thread = cx.new(|_cx| {
+                AcpThread::new(
+                    "Test",
+                    self.clone(),
+                    project,
+                    action_log,
+                    session_id.clone(),
+                )
+            });
             self.sessions.lock().insert(
                 session_id,
                 Session {

crates/acp_thread/src/diff.rs 🔗

@@ -1,4 +1,3 @@
-use agent_client_protocol as acp;
 use anyhow::Result;
 use buffer_diff::{BufferDiff, BufferDiffSnapshot};
 use editor::{MultiBuffer, PathKey};
@@ -21,17 +20,13 @@ pub enum Diff {
 }
 
 impl Diff {
-    pub fn from_acp(
-        diff: acp::Diff,
+    pub fn finalized(
+        path: PathBuf,
+        old_text: Option<String>,
+        new_text: String,
         language_registry: Arc<LanguageRegistry>,
         cx: &mut Context<Self>,
     ) -> Self {
-        let acp::Diff {
-            path,
-            old_text,
-            new_text,
-        } = diff;
-
         let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
 
         let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));

crates/acp_thread/src/mention.rs 🔗

@@ -2,6 +2,7 @@ use agent::ThreadId;
 use anyhow::{Context as _, Result, bail};
 use file_icons::FileIcons;
 use prompt_store::{PromptId, UserPromptId};
+use serde::{Deserialize, Serialize};
 use std::{
     fmt,
     ops::Range,
@@ -11,7 +12,7 @@ use std::{
 use ui::{App, IconName, SharedString};
 use url::Url;
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
 pub enum MentionUri {
     File {
         abs_path: PathBuf,

crates/agent2/Cargo.toml 🔗

@@ -14,6 +14,7 @@ workspace = true
 [dependencies]
 acp_thread.workspace = true
 action_log.workspace = true
+agent.workspace = true
 agent-client-protocol.workspace = true
 agent_servers.workspace = true
 agent_settings.workspace = true
@@ -26,6 +27,7 @@ collections.workspace = true
 context_server.workspace = true
 fs.workspace = true
 futures.workspace = true
+git.workspace = true
 gpui.workspace = true
 handlebars = { workspace = true, features = ["rust-embed"] }
 html_to_markdown.workspace = true
@@ -59,6 +61,7 @@ which.workspace = true
 workspace-hack.workspace = true
 
 [dev-dependencies]
+agent = { workspace = true, "features" = ["test-support"] }
 ctor.workspace = true
 client = { workspace = true, "features" = ["test-support"] }
 clock = { workspace = true, "features" = ["test-support"] }
@@ -66,6 +69,7 @@ context_server = { workspace = true, "features" = ["test-support"] }
 editor = { workspace = true, "features" = ["test-support"] }
 env_logger.workspace = true
 fs = { workspace = true, "features" = ["test-support"] }
+git = { workspace = true, "features" = ["test-support"] }
 gpui = { workspace = true, "features" = ["test-support"] }
 gpui_tokio.workspace = true
 language = { workspace = true, "features" = ["test-support"] }

crates/agent2/src/agent.rs 🔗

@@ -1,10 +1,11 @@
 use crate::{
-    AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
-    DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
-    MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
-    ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
+    ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
+    EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
+    OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
+    UserMessageContent, WebSearchTool, templates::Templates,
 };
 use acp_thread::AgentModelSelector;
+use action_log::ActionLog;
 use agent_client_protocol as acp;
 use agent_settings::AgentSettings;
 use anyhow::{Context as _, Result, anyhow};
@@ -427,18 +428,19 @@ impl NativeAgent {
     ) {
         self.models.refresh_list(cx);
 
-        let default_model = LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .map(|m| m.model.clone());
+        let registry = LanguageModelRegistry::read_global(cx);
+        let default_model = registry.default_model().map(|m| m.model.clone());
+        let summarization_model = registry.thread_summary_model().map(|m| m.model.clone());
 
         for session in self.sessions.values_mut() {
             session.thread.update(cx, |thread, cx| {
                 if thread.model().is_none()
                     && let Some(model) = default_model.clone()
                 {
-                    thread.set_model(model);
+                    thread.set_model(model, cx);
                     cx.notify();
                 }
+                thread.set_summarization_model(summarization_model.clone(), cx);
             });
         }
     }
@@ -462,10 +464,7 @@ impl NativeAgentConnection {
         session_id: acp::SessionId,
         cx: &mut App,
         f: impl 'static
-        + FnOnce(
-            Entity<Thread>,
-            &mut App,
-        ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
+        + FnOnce(Entity<Thread>, &mut App) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>,
     ) -> Task<Result<acp::PromptResponse>> {
         let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
             agent
@@ -489,7 +488,18 @@ impl NativeAgentConnection {
                         log::trace!("Received completion event: {:?}", event);
 
                         match event {
-                            AgentResponseEvent::Text(text) => {
+                            ThreadEvent::UserMessage(message) => {
+                                acp_thread.update(cx, |thread, cx| {
+                                    for content in message.content {
+                                        thread.push_user_content_block(
+                                            Some(message.id.clone()),
+                                            content.into(),
+                                            cx,
+                                        );
+                                    }
+                                })?;
+                            }
+                            ThreadEvent::AgentText(text) => {
                                 acp_thread.update(cx, |thread, cx| {
                                     thread.push_assistant_content_block(
                                         acp::ContentBlock::Text(acp::TextContent {
@@ -501,7 +511,7 @@ impl NativeAgentConnection {
                                     )
                                 })?;
                             }
-                            AgentResponseEvent::Thinking(text) => {
+                            ThreadEvent::AgentThinking(text) => {
                                 acp_thread.update(cx, |thread, cx| {
                                     thread.push_assistant_content_block(
                                         acp::ContentBlock::Text(acp::TextContent {
@@ -513,7 +523,7 @@ impl NativeAgentConnection {
                                     )
                                 })?;
                             }
-                            AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
+                            ThreadEvent::ToolCallAuthorization(ToolCallAuthorization {
                                 tool_call,
                                 options,
                                 response,
@@ -536,22 +546,26 @@ impl NativeAgentConnection {
                                 })
                                 .detach();
                             }
-                            AgentResponseEvent::ToolCall(tool_call) => {
+                            ThreadEvent::ToolCall(tool_call) => {
                                 acp_thread.update(cx, |thread, cx| {
                                     thread.upsert_tool_call(tool_call, cx)
                                 })??;
                             }
-                            AgentResponseEvent::ToolCallUpdate(update) => {
+                            ThreadEvent::ToolCallUpdate(update) => {
                                 acp_thread.update(cx, |thread, cx| {
                                     thread.update_tool_call(update, cx)
                                 })??;
                             }
-                            AgentResponseEvent::Retry(status) => {
+                            ThreadEvent::TitleUpdate(title) => {
+                                acp_thread
+                                    .update(cx, |thread, cx| thread.update_title(title, cx))??;
+                            }
+                            ThreadEvent::Retry(status) => {
                                 acp_thread.update(cx, |thread, cx| {
                                     thread.update_retry_status(status, cx)
                                 })?;
                             }
-                            AgentResponseEvent::Stop(stop_reason) => {
+                            ThreadEvent::Stop(stop_reason) => {
                                 log::debug!("Assistant message complete: {:?}", stop_reason);
                                 return Ok(acp::PromptResponse { stop_reason });
                             }
@@ -604,8 +618,8 @@ impl AgentModelSelector for NativeAgentConnection {
             return Task::ready(Err(anyhow!("Invalid model ID {}", model_id)));
         };
 
-        thread.update(cx, |thread, _cx| {
-            thread.set_model(model.clone());
+        thread.update(cx, |thread, cx| {
+            thread.set_model(model.clone(), cx);
         });
 
         update_settings_file::<AgentSettings>(
@@ -665,30 +679,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         cx.spawn(async move |cx| {
             log::debug!("Starting thread creation in async context");
 
-            // Generate session ID
-            let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
-            log::info!("Created session with ID: {}", session_id);
-
-            // Create AcpThread
-            let acp_thread = cx.update(|cx| {
-                cx.new(|cx| {
-                    acp_thread::AcpThread::new(
-                        "agent2",
-                        self.clone(),
-                        project.clone(),
-                        session_id.clone(),
-                        cx,
-                    )
-                })
-            })?;
-            let action_log = cx.update(|cx| acp_thread.read(cx).action_log().clone())?;
-
+            let action_log = cx.new(|_cx| ActionLog::new(project.clone()))?;
             // Create Thread
             let thread = agent.update(
                 cx,
                 |agent, cx: &mut gpui::Context<NativeAgent>| -> Result<_> {
                     // Fetch default model from registry settings
                     let registry = LanguageModelRegistry::read_global(cx);
+                    let language_registry = project.read(cx).languages().clone();
 
                     // Log available models for debugging
                     let available_count = registry.available_models(cx).count();
@@ -699,6 +697,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                             .models
                             .model_from_id(&LanguageModels::model_id(&default_model.model))
                     });
+                    let summarization_model = registry.thread_summary_model().map(|c| c.model);
 
                     let thread = cx.new(|cx| {
                         let mut thread = Thread::new(
@@ -708,13 +707,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                             action_log.clone(),
                             agent.templates.clone(),
                             default_model,
+                            summarization_model,
                             cx,
                         );
                         thread.add_tool(CopyPathTool::new(project.clone()));
                         thread.add_tool(CreateDirectoryTool::new(project.clone()));
                         thread.add_tool(DeletePathTool::new(project.clone(), action_log.clone()));
                         thread.add_tool(DiagnosticsTool::new(project.clone()));
-                        thread.add_tool(EditFileTool::new(cx.entity()));
+                        thread.add_tool(EditFileTool::new(cx.weak_entity(), language_registry));
                         thread.add_tool(FetchTool::new(project.read(cx).client().http_client()));
                         thread.add_tool(FindPathTool::new(project.clone()));
                         thread.add_tool(GrepTool::new(project.clone()));
@@ -722,7 +722,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                         thread.add_tool(MovePathTool::new(project.clone()));
                         thread.add_tool(NowTool);
                         thread.add_tool(OpenTool::new(project.clone()));
-                        thread.add_tool(ReadFileTool::new(project.clone(), action_log));
+                        thread.add_tool(ReadFileTool::new(project.clone(), action_log.clone()));
                         thread.add_tool(TerminalTool::new(project.clone(), cx));
                         thread.add_tool(ThinkingTool);
                         thread.add_tool(WebSearchTool); // TODO: Enable this only if it's a zed model.
@@ -733,6 +733,21 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                 },
             )??;
 
+            let session_id = thread.read_with(cx, |thread, _| thread.id().clone())?;
+            log::info!("Created session with ID: {}", session_id);
+            // Create AcpThread
+            let acp_thread = cx.update(|cx| {
+                cx.new(|_cx| {
+                    acp_thread::AcpThread::new(
+                        "agent2",
+                        self.clone(),
+                        project.clone(),
+                        action_log.clone(),
+                        session_id.clone(),
+                    )
+                })
+            })?;
+
             // Store the session
             agent.update(cx, |agent, cx| {
                 agent.sessions.insert(
@@ -803,7 +818,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         log::info!("Cancelling on session: {}", session_id);
         self.0.update(cx, |agent, cx| {
             if let Some(agent) = agent.sessions.get(session_id) {
-                agent.thread.update(cx, |thread, _cx| thread.cancel());
+                agent.thread.update(cx, |thread, cx| thread.cancel(cx));
             }
         });
     }
@@ -830,7 +845,10 @@ struct NativeAgentSessionEditor(Entity<Thread>);
 
 impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
     fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
-        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
+        Task::ready(
+            self.0
+                .update(cx, |thread, cx| thread.truncate(message_id, cx)),
+        )
     }
 }
 

crates/agent2/src/tests/mod.rs 🔗

@@ -345,7 +345,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 
     let mut saw_partial_tool_use = false;
     while let Some(event) = events.next().await {
-        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
+        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
             thread.update(cx, |thread, _cx| {
                 // Look for a tool use in the thread's last message
                 let message = thread.last_message().unwrap();
@@ -735,16 +735,14 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
     );
 }
 
-async fn expect_tool_call(
-    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
-) -> acp::ToolCall {
+async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
     let event = events
         .next()
         .await
         .expect("no tool call authorization event received")
         .unwrap();
     match event {
-        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
+        ThreadEvent::ToolCall(tool_call) => return tool_call,
         event => {
             panic!("Unexpected event {event:?}");
         }
@@ -752,7 +750,7 @@ async fn expect_tool_call(
 }
 
 async fn expect_tool_call_update_fields(
-    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
+    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 ) -> acp::ToolCallUpdate {
     let event = events
         .next()
@@ -760,7 +758,7 @@ async fn expect_tool_call_update_fields(
         .expect("no tool call authorization event received")
         .unwrap();
     match event {
-        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
+        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
             return update;
         }
         event => {
@@ -770,7 +768,7 @@ async fn expect_tool_call_update_fields(
 }
 
 async fn next_tool_call_authorization(
-    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
+    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 ) -> ToolCallAuthorization {
     loop {
         let event = events
@@ -778,7 +776,7 @@ async fn next_tool_call_authorization(
             .await
             .expect("no tool call authorization event received")
             .unwrap();
-        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
+        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
             let permission_kinds = tool_call_authorization
                 .options
                 .iter()
@@ -945,13 +943,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
     let mut echo_completed = false;
     while let Some(event) = events.next().await {
         match event.unwrap() {
-            AgentResponseEvent::ToolCall(tool_call) => {
+            ThreadEvent::ToolCall(tool_call) => {
                 assert_eq!(tool_call.title, expected_tools.remove(0));
                 if tool_call.title == "Echo" {
                     echo_id = Some(tool_call.id);
                 }
             }
-            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
+            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                 acp::ToolCallUpdate {
                     id,
                     fields:
@@ -973,13 +971,13 @@ async fn test_cancellation(cx: &mut TestAppContext) {
 
     // Cancel the current send and ensure that the event stream is closed, even
     // if one of the tools is still running.
-    thread.update(cx, |thread, _cx| thread.cancel());
+    thread.update(cx, |thread, cx| thread.cancel(cx));
     let events = events.collect::<Vec<_>>().await;
     let last_event = events.last();
     assert!(
         matches!(
             last_event,
-            Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
+            Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
         ),
         "unexpected event {last_event:?}"
     );
@@ -1161,7 +1159,7 @@ async fn test_truncate(cx: &mut TestAppContext) {
     });
 
     thread
-        .update(cx, |thread, _cx| thread.truncate(message_id))
+        .update(cx, |thread, cx| thread.truncate(message_id, cx))
         .unwrap();
     cx.run_until_parked();
     thread.read_with(cx, |thread, _| {
@@ -1203,6 +1201,51 @@ async fn test_truncate(cx: &mut TestAppContext) {
     });
 }
 
+#[gpui::test]
+async fn test_title_generation(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let summary_model = Arc::new(FakeLanguageModel::default());
+    thread.update(cx, |thread, cx| {
+        thread.set_summarization_model(Some(summary_model.clone()), cx)
+    });
+
+    let send = thread
+        .update(cx, |thread, cx| {
+            thread.send(UserMessageId::new(), ["Hello"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    fake_model.send_last_completion_stream_text_chunk("Hey!");
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
+
+    // Ensure the summary model has been invoked to generate a title.
+    summary_model.send_last_completion_stream_text_chunk("Hello ");
+    summary_model.send_last_completion_stream_text_chunk("world\nG");
+    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
+    summary_model.end_last_completion_stream();
+    send.collect::<Vec<_>>().await;
+    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+
+    // Send another message, ensuring no title is generated this time.
+    let send = thread
+        .update(cx, |thread, cx| {
+            thread.send(UserMessageId::new(), ["Hello again"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Hey again!");
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+    assert_eq!(summary_model.pending_completions(), Vec::new());
+    send.collect::<Vec<_>>().await;
+    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
+}
+
 #[gpui::test]
 async fn test_agent_connection(cx: &mut TestAppContext) {
     cx.update(settings::init);
@@ -1442,7 +1485,7 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
             thread.send(UserMessageId::new(), ["Hello!"], cx)
         })
         .unwrap();
@@ -1454,10 +1497,10 @@ async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
     let mut retry_events = Vec::new();
     while let Some(Ok(event)) = events.next().await {
         match event {
-            AgentResponseEvent::Retry(retry_status) => {
+            ThreadEvent::Retry(retry_status) => {
                 retry_events.push(retry_status);
             }
-            AgentResponseEvent::Stop(..) => break,
+            ThreadEvent::Stop(..) => break,
             _ => {}
         }
     }
@@ -1486,7 +1529,7 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
             thread.send(UserMessageId::new(), ["Hello!"], cx)
         })
         .unwrap();
@@ -1507,10 +1550,10 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
     let mut retry_events = Vec::new();
     while let Some(Ok(event)) = events.next().await {
         match event {
-            AgentResponseEvent::Retry(retry_status) => {
+            ThreadEvent::Retry(retry_status) => {
                 retry_events.push(retry_status);
             }
-            AgentResponseEvent::Stop(..) => break,
+            ThreadEvent::Stop(..) => break,
             _ => {}
         }
     }
@@ -1543,7 +1586,7 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
 
     let mut events = thread
         .update(cx, |thread, cx| {
-            thread.set_completion_mode(agent_settings::CompletionMode::Burn);
+            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
             thread.send(UserMessageId::new(), ["Hello!"], cx)
         })
         .unwrap();
@@ -1565,10 +1608,10 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
     let mut retry_events = Vec::new();
     while let Some(event) = events.next().await {
         match event {
-            Ok(AgentResponseEvent::Retry(retry_status)) => {
+            Ok(ThreadEvent::Retry(retry_status)) => {
                 retry_events.push(retry_status);
             }
-            Ok(AgentResponseEvent::Stop(..)) => break,
+            Ok(ThreadEvent::Stop(..)) => break,
             Err(error) => errors.push(error),
             _ => {}
         }
@@ -1592,11 +1635,11 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
 }
 
 /// Filters out the stop events for asserting against in tests
-fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
+fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
     result_events
         .into_iter()
         .filter_map(|event| match event.unwrap() {
-            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
+            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
             _ => None,
         })
         .collect()
@@ -1713,6 +1756,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
             action_log,
             templates,
             Some(model.clone()),
+            None,
             cx,
         )
     });

crates/agent2/src/thread.rs 🔗

@@ -1,25 +1,34 @@
 use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
 use acp_thread::{MentionUri, UserMessageId};
 use action_log::ActionLog;
+use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
 use agent_client_protocol as acp;
-use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
+use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
+use chrono::{DateTime, Utc};
 use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
 use collections::IndexMap;
 use fs::Fs;
 use futures::{
+    FutureExt,
     channel::{mpsc, oneshot},
+    future::Shared,
     stream::FuturesUnordered,
 };
+use git::repository::DiffType;
 use gpui::{App, AsyncApp, Context, Entity, SharedString, Task, WeakEntity};
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
     LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
     LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
     LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
+    TokenUsage,
+};
+use project::{
+    Project,
+    git_store::{GitStore, RepositoryState},
 };
-use project::Project;
 use prompt_store::ProjectContext;
 use schemars::{JsonSchema, Schema};
 use serde::{Deserialize, Serialize};
@@ -35,28 +44,7 @@ use std::{fmt::Write, ops::Range};
 use util::{ResultExt, markdown::MarkdownCodeBlock};
 use uuid::Uuid;
 
-#[derive(
-    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
-)]
-pub struct ThreadId(Arc<str>);
-
-impl ThreadId {
-    pub fn new() -> Self {
-        Self(Uuid::new_v4().to_string().into())
-    }
-}
-
-impl std::fmt::Display for ThreadId {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.0)
-    }
-}
-
-impl From<&str> for ThreadId {
-    fn from(value: &str) -> Self {
-        Self(value.into())
-    }
-}
+const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user";
 
 /// The ID of the user prompt that initiated a request.
 ///
@@ -91,7 +79,7 @@ enum RetryStrategy {
     },
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub enum Message {
     User(UserMessage),
     Agent(AgentMessage),
@@ -106,6 +94,18 @@ impl Message {
         }
     }
 
+    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
+        match self {
+            Message::User(message) => vec![message.to_request()],
+            Message::Agent(message) => message.to_request(),
+            Message::Resume => vec![LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Continue where you left off".into()],
+                cache: false,
+            }],
+        }
+    }
+
     pub fn to_markdown(&self) -> String {
         match self {
             Message::User(message) => message.to_markdown(),
@@ -113,15 +113,22 @@ impl Message {
             Message::Resume => "[resumed after tool use limit was reached]".into(),
         }
     }
+
+    pub fn role(&self) -> Role {
+        match self {
+            Message::User(_) | Message::Resume => Role::User,
+            Message::Agent(_) => Role::Assistant,
+        }
+    }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct UserMessage {
     pub id: UserMessageId,
     pub content: Vec<UserMessageContent>,
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub enum UserMessageContent {
     Text(String),
     Mention { uri: MentionUri, content: String },
@@ -345,9 +352,6 @@ impl AgentMessage {
                 AgentMessageContent::RedactedThinking(_) => {
                     markdown.push_str("<redacted_thinking />\n")
                 }
-                AgentMessageContent::Image(_) => {
-                    markdown.push_str("<image />\n");
-                }
                 AgentMessageContent::ToolUse(tool_use) => {
                     markdown.push_str(&format!(
                         "**Tool Use**: {} (ID: {})\n",
@@ -418,9 +422,6 @@ impl AgentMessage {
                 AgentMessageContent::ToolUse(value) => {
                     language_model::MessageContent::ToolUse(value.clone())
                 }
-                AgentMessageContent::Image(value) => {
-                    language_model::MessageContent::Image(value.clone())
-                }
             };
             assistant_message.content.push(chunk);
         }
@@ -450,13 +451,13 @@ impl AgentMessage {
     }
 }
 
-#[derive(Default, Debug, Clone, PartialEq, Eq)]
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct AgentMessage {
     pub content: Vec<AgentMessageContent>,
     pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub enum AgentMessageContent {
     Text(String),
     Thinking {
@@ -464,17 +465,18 @@ pub enum AgentMessageContent {
         signature: Option<String>,
     },
     RedactedThinking(String),
-    Image(LanguageModelImage),
     ToolUse(LanguageModelToolUse),
 }
 
 #[derive(Debug)]
-pub enum AgentResponseEvent {
-    Text(String),
-    Thinking(String),
+pub enum ThreadEvent {
+    UserMessage(UserMessage),
+    AgentText(String),
+    AgentThinking(String),
     ToolCall(acp::ToolCall),
     ToolCallUpdate(acp_thread::ToolCallUpdate),
     ToolCallAuthorization(ToolCallAuthorization),
+    TitleUpdate(SharedString),
     Retry(acp_thread::RetryStatus),
     Stop(acp::StopReason),
 }
@@ -487,8 +489,12 @@ pub struct ToolCallAuthorization {
 }
 
 pub struct Thread {
-    id: ThreadId,
+    id: acp::SessionId,
     prompt_id: PromptId,
+    updated_at: DateTime<Utc>,
+    title: Option<SharedString>,
+    #[allow(unused)]
+    summary: DetailedSummaryState,
     messages: Vec<Message>,
     completion_mode: CompletionMode,
     /// Holds the task that handles agent interaction until the end of the turn.
@@ -498,11 +504,18 @@ pub struct Thread {
     pending_message: Option<AgentMessage>,
     tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
     tool_use_limit_reached: bool,
+    #[allow(unused)]
+    request_token_usage: Vec<TokenUsage>,
+    #[allow(unused)]
+    cumulative_token_usage: TokenUsage,
+    #[allow(unused)]
+    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
     context_server_registry: Entity<ContextServerRegistry>,
     profile_id: AgentProfileId,
     project_context: Entity<ProjectContext>,
     templates: Arc<Templates>,
     model: Option<Arc<dyn LanguageModel>>,
+    summarization_model: Option<Arc<dyn LanguageModel>>,
     project: Entity<Project>,
     action_log: Entity<ActionLog>,
 }
@@ -515,36 +528,254 @@ impl Thread {
         action_log: Entity<ActionLog>,
         templates: Arc<Templates>,
         model: Option<Arc<dyn LanguageModel>>,
+        summarization_model: Option<Arc<dyn LanguageModel>>,
         cx: &mut Context<Self>,
     ) -> Self {
         let profile_id = AgentSettings::get_global(cx).default_profile.clone();
         Self {
-            id: ThreadId::new(),
+            id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()),
             prompt_id: PromptId::new(),
+            updated_at: Utc::now(),
+            title: None,
+            summary: DetailedSummaryState::default(),
             messages: Vec::new(),
             completion_mode: AgentSettings::get_global(cx).preferred_completion_mode,
             running_turn: None,
             pending_message: None,
             tools: BTreeMap::default(),
             tool_use_limit_reached: false,
+            request_token_usage: Vec::new(),
+            cumulative_token_usage: TokenUsage::default(),
+            initial_project_snapshot: {
+                let project_snapshot = Self::project_snapshot(project.clone(), cx);
+                cx.foreground_executor()
+                    .spawn(async move { Some(project_snapshot.await) })
+                    .shared()
+            },
             context_server_registry,
             profile_id,
             project_context,
             templates,
             model,
+            summarization_model,
             project,
             action_log,
         }
     }
 
-    pub fn project(&self) -> &Entity<Project> {
-        &self.project
+    pub fn id(&self) -> &acp::SessionId {
+        &self.id
+    }
+
+    pub fn replay(
+        &mut self,
+        cx: &mut Context<Self>,
+    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
+        let (tx, rx) = mpsc::unbounded();
+        let stream = ThreadEventStream(tx);
+        for message in &self.messages {
+            match message {
+                Message::User(user_message) => stream.send_user_message(user_message),
+                Message::Agent(assistant_message) => {
+                    for content in &assistant_message.content {
+                        match content {
+                            AgentMessageContent::Text(text) => stream.send_text(text),
+                            AgentMessageContent::Thinking { text, .. } => {
+                                stream.send_thinking(text)
+                            }
+                            AgentMessageContent::RedactedThinking(_) => {}
+                            AgentMessageContent::ToolUse(tool_use) => {
+                                self.replay_tool_call(
+                                    tool_use,
+                                    assistant_message.tool_results.get(&tool_use.id),
+                                    &stream,
+                                    cx,
+                                );
+                            }
+                        }
+                    }
+                }
+                Message::Resume => {}
+            }
+        }
+        rx
+    }
+
+    fn replay_tool_call(
+        &self,
+        tool_use: &LanguageModelToolUse,
+        tool_result: Option<&LanguageModelToolResult>,
+        stream: &ThreadEventStream,
+        cx: &mut Context<Self>,
+    ) {
+        let Some(tool) = self.tools.get(tool_use.name.as_ref()) else {
+            stream
+                .0
+                .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall {
+                    id: acp::ToolCallId(tool_use.id.to_string().into()),
+                    title: tool_use.name.to_string(),
+                    kind: acp::ToolKind::Other,
+                    status: acp::ToolCallStatus::Failed,
+                    content: Vec::new(),
+                    locations: Vec::new(),
+                    raw_input: Some(tool_use.input.clone()),
+                    raw_output: None,
+                })))
+                .ok();
+            return;
+        };
+
+        let title = tool.initial_title(tool_use.input.clone());
+        let kind = tool.kind();
+        stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
+
+        let output = tool_result
+            .as_ref()
+            .and_then(|result| result.output.clone());
+        if let Some(output) = output.clone() {
+            let tool_event_stream = ToolCallEventStream::new(
+                tool_use.id.clone(),
+                stream.clone(),
+                Some(self.project.read(cx).fs().clone()),
+            );
+            tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
+                .log_err();
+        }
+
+        stream.update_tool_call_fields(
+            &tool_use.id,
+            acp::ToolCallUpdateFields {
+                status: Some(acp::ToolCallStatus::Completed),
+                raw_output: output,
+                ..Default::default()
+            },
+        );
+    }
+
+    /// Create a snapshot of the current project state including git information and unsaved buffers.
+    fn project_snapshot(
+        project: Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> Task<Arc<agent::thread::ProjectSnapshot>> {
+        let git_store = project.read(cx).git_store().clone();
+        let worktree_snapshots: Vec<_> = project
+            .read(cx)
+            .visible_worktrees(cx)
+            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
+            .collect();
+
+        cx.spawn(async move |_, cx| {
+            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
+
+            let mut unsaved_buffers = Vec::new();
+            cx.update(|app_cx| {
+                let buffer_store = project.read(app_cx).buffer_store();
+                for buffer_handle in buffer_store.read(app_cx).buffers() {
+                    let buffer = buffer_handle.read(app_cx);
+                    if buffer.is_dirty()
+                        && let Some(file) = buffer.file()
+                    {
+                        let path = file.path().to_string_lossy().to_string();
+                        unsaved_buffers.push(path);
+                    }
+                }
+            })
+            .ok();
+
+            Arc::new(ProjectSnapshot {
+                worktree_snapshots,
+                unsaved_buffer_paths: unsaved_buffers,
+                timestamp: Utc::now(),
+            })
+        })
+    }
+
+    fn worktree_snapshot(
+        worktree: Entity<project::Worktree>,
+        git_store: Entity<GitStore>,
+        cx: &App,
+    ) -> Task<agent::thread::WorktreeSnapshot> {
+        cx.spawn(async move |cx| {
+            // Get worktree path and snapshot
+            let worktree_info = cx.update(|app_cx| {
+                let worktree = worktree.read(app_cx);
+                let path = worktree.abs_path().to_string_lossy().to_string();
+                let snapshot = worktree.snapshot();
+                (path, snapshot)
+            });
+
+            let Ok((worktree_path, _snapshot)) = worktree_info else {
+                return WorktreeSnapshot {
+                    worktree_path: String::new(),
+                    git_state: None,
+                };
+            };
+
+            let git_state = git_store
+                .update(cx, |git_store, cx| {
+                    git_store
+                        .repositories()
+                        .values()
+                        .find(|repo| {
+                            repo.read(cx)
+                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
+                                .is_some()
+                        })
+                        .cloned()
+                })
+                .ok()
+                .flatten()
+                .map(|repo| {
+                    repo.update(cx, |repo, _| {
+                        let current_branch =
+                            repo.branch.as_ref().map(|branch| branch.name().to_owned());
+                        repo.send_job(None, |state, _| async move {
+                            let RepositoryState::Local { backend, .. } = state else {
+                                return GitState {
+                                    remote_url: None,
+                                    head_sha: None,
+                                    current_branch,
+                                    diff: None,
+                                };
+                            };
+
+                            let remote_url = backend.remote_url("origin");
+                            let head_sha = backend.head_sha().await;
+                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
+
+                            GitState {
+                                remote_url,
+                                head_sha,
+                                current_branch,
+                                diff,
+                            }
+                        })
+                    })
+                });
+
+            let git_state = match git_state {
+                Some(git_state) => match git_state.ok() {
+                    Some(git_state) => git_state.await.ok(),
+                    None => None,
+                },
+                None => None,
+            };
+
+            WorktreeSnapshot {
+                worktree_path,
+                git_state,
+            }
+        })
     }
 
     pub fn project_context(&self) -> &Entity<ProjectContext> {
         &self.project_context
     }
 
+    pub fn project(&self) -> &Entity<Project> {
+        &self.project
+    }
+
     pub fn action_log(&self) -> &Entity<ActionLog> {
         &self.action_log
     }
@@ -553,16 +784,27 @@ impl Thread {
         self.model.as_ref()
     }
 
-    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
+    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
         self.model = Some(model);
+        cx.notify()
+    }
+
+    pub fn set_summarization_model(
+        &mut self,
+        model: Option<Arc<dyn LanguageModel>>,
+        cx: &mut Context<Self>,
+    ) {
+        self.summarization_model = model;
+        cx.notify()
     }
 
     pub fn completion_mode(&self) -> CompletionMode {
         self.completion_mode
     }
 
-    pub fn set_completion_mode(&mut self, mode: CompletionMode) {
+    pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context<Self>) {
         self.completion_mode = mode;
+        cx.notify()
     }
 
     #[cfg(any(test, feature = "test-support"))]
@@ -590,29 +832,29 @@ impl Thread {
         self.profile_id = profile_id;
     }
 
-    pub fn cancel(&mut self) {
+    pub fn cancel(&mut self, cx: &mut Context<Self>) {
         if let Some(running_turn) = self.running_turn.take() {
             running_turn.cancel();
         }
-        self.flush_pending_message();
+        self.flush_pending_message(cx);
     }
 
-    pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
-        self.cancel();
+    pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context<Self>) -> Result<()> {
+        self.cancel(cx);
         let Some(position) = self.messages.iter().position(
             |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
         ) else {
             return Err(anyhow!("Message not found"));
         };
         self.messages.truncate(position);
+        cx.notify();
         Ok(())
     }
 
     pub fn resume(
         &mut self,
         cx: &mut Context<Self>,
-    ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
-        anyhow::ensure!(self.model.is_some(), "Model not set");
+    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
         anyhow::ensure!(
             self.tool_use_limit_reached,
             "can only resume after tool use limit is reached"
@@ -633,7 +875,7 @@ impl Thread {
         id: UserMessageId,
         content: impl IntoIterator<Item = T>,
         cx: &mut Context<Self>,
-    ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>
+    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
     where
         T: Into<UserMessageContent>,
     {
@@ -656,22 +898,19 @@ impl Thread {
     fn run_turn(
         &mut self,
         cx: &mut Context<Self>,
-    ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
-        self.cancel();
-
-        let model = self
-            .model()
-            .cloned()
-            .context("No language model configured")?;
-        let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
-        let event_stream = AgentResponseEventStream(events_tx);
+    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
+        self.cancel(cx);
+
+        let model = self.model.clone().context("No language model configured")?;
+        let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
+        let event_stream = ThreadEventStream(events_tx);
         let message_ix = self.messages.len().saturating_sub(1);
         self.tool_use_limit_reached = false;
         self.running_turn = Some(RunningTurn {
             event_stream: event_stream.clone(),
             _task: cx.spawn(async move |this, cx| {
                 log::info!("Starting agent turn execution");
-                let turn_result: Result<()> = async {
+                let turn_result: Result<StopReason> = async {
                     let mut completion_intent = CompletionIntent::UserPrompt;
                     loop {
                         log::debug!(
@@ -685,18 +924,27 @@ impl Thread {
                         log::info!("Calling model.stream_completion");
 
                         let mut tool_use_limit_reached = false;
+                        let mut refused = false;
+                        let mut reached_max_tokens = false;
                         let mut tool_uses = Self::stream_completion_with_retries(
                             this.clone(),
                             model.clone(),
                             request,
-                            message_ix,
                             &event_stream,
                             &mut tool_use_limit_reached,
+                            &mut refused,
+                            &mut reached_max_tokens,
                             cx,
                         )
                         .await?;
 
-                        let used_tools = tool_uses.is_empty();
+                        if refused {
+                            return Ok(StopReason::Refusal);
+                        } else if reached_max_tokens {
+                            return Ok(StopReason::MaxTokens);
+                        }
+
+                        let end_turn = tool_uses.is_empty();
                         while let Some(tool_result) = tool_uses.next().await {
                             log::info!("Tool finished {:?}", tool_result);
 
@@ -724,29 +972,42 @@ impl Thread {
                             log::info!("Tool use limit reached, completing turn");
                             this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
                             return Err(language_model::ToolUseLimitReachedError.into());
-                        } else if used_tools {
+                        } else if end_turn {
                             log::info!("No tool uses found, completing turn");
-                            return Ok(());
+                            return Ok(StopReason::EndTurn);
                         } else {
-                            this.update(cx, |this, _| this.flush_pending_message())?;
+                            this.update(cx, |this, cx| this.flush_pending_message(cx))?;
                             completion_intent = CompletionIntent::ToolResults;
                         }
                     }
                 }
                 .await;
+                _ = this.update(cx, |this, cx| this.flush_pending_message(cx));
+
+                match turn_result {
+                    Ok(reason) => {
+                        log::info!("Turn execution completed: {:?}", reason);
+
+                        let update_title = this
+                            .update(cx, |this, cx| this.update_title(&event_stream, cx))
+                            .ok()
+                            .flatten();
+                        if let Some(update_title) = update_title {
+                            update_title.await.context("update title failed").log_err();
+                        }
 
-                if let Err(error) = turn_result {
-                    log::error!("Turn execution failed: {:?}", error);
-                    event_stream.send_error(error);
-                } else {
-                    log::info!("Turn execution completed successfully");
+                        event_stream.send_stop(reason);
+                        if reason == StopReason::Refusal {
+                            _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
+                        }
+                    }
+                    Err(error) => {
+                        log::error!("Turn execution failed: {:?}", error);
+                        event_stream.send_error(error);
+                    }
                 }
 
-                this.update(cx, |this, _| {
-                    this.flush_pending_message();
-                    this.running_turn.take();
-                })
-                .ok();
+                _ = this.update(cx, |this, _| this.running_turn.take());
             }),
         });
         Ok(events_rx)
@@ -756,9 +1017,10 @@ impl Thread {
         this: WeakEntity<Self>,
         model: Arc<dyn LanguageModel>,
         request: LanguageModelRequest,
-        message_ix: usize,
-        event_stream: &AgentResponseEventStream,
+        event_stream: &ThreadEventStream,
         tool_use_limit_reached: &mut bool,
+        refusal: &mut bool,
+        max_tokens_reached: &mut bool,
         cx: &mut AsyncApp,
     ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
         log::debug!("Stream completion started successfully");
@@ -774,16 +1036,17 @@ impl Thread {
                     )) => {
                         *tool_use_limit_reached = true;
                     }
-                    Ok(LanguageModelCompletionEvent::Stop(reason)) => {
-                        event_stream.send_stop(reason);
-                        if reason == StopReason::Refusal {
-                            this.update(cx, |this, _cx| {
-                                this.flush_pending_message();
-                                this.messages.truncate(message_ix);
-                            })?;
-                            return Ok(tool_uses);
-                        }
+                    Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
+                        *refusal = true;
+                        return Ok(FuturesUnordered::default());
+                    }
+                    Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
+                        *max_tokens_reached = true;
+                        return Ok(FuturesUnordered::default());
                     }
+                    Ok(LanguageModelCompletionEvent::Stop(
+                        StopReason::ToolUse | StopReason::EndTurn,
+                    )) => break,
                     Ok(event) => {
                         log::trace!("Received completion event: {:?}", event);
                         this.update(cx, |this, cx| {
@@ -843,6 +1106,7 @@ impl Thread {
                     }
                 }
             }
+
             return Ok(tool_uses);
         }
     }
@@ -870,7 +1134,7 @@ impl Thread {
     fn handle_streamed_completion_event(
         &mut self,
         event: LanguageModelCompletionEvent,
-        event_stream: &AgentResponseEventStream,
+        event_stream: &ThreadEventStream,
         cx: &mut Context<Self>,
     ) -> Option<Task<LanguageModelToolResult>> {
         log::trace!("Handling streamed completion event: {:?}", event);
@@ -878,7 +1142,7 @@ impl Thread {
 
         match event {
             StartMessage { .. } => {
-                self.flush_pending_message();
+                self.flush_pending_message(cx);
                 self.pending_message = Some(AgentMessage::default());
             }
             Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
@@ -912,7 +1176,7 @@ impl Thread {
     fn handle_text_event(
         &mut self,
         new_text: String,
-        event_stream: &AgentResponseEventStream,
+        event_stream: &ThreadEventStream,
         cx: &mut Context<Self>,
     ) {
         event_stream.send_text(&new_text);
@@ -933,7 +1197,7 @@ impl Thread {
         &mut self,
         new_text: String,
         new_signature: Option<String>,
-        event_stream: &AgentResponseEventStream,
+        event_stream: &ThreadEventStream,
         cx: &mut Context<Self>,
     ) {
         event_stream.send_thinking(&new_text);
@@ -965,7 +1229,7 @@ impl Thread {
     fn handle_tool_use_event(
         &mut self,
         tool_use: LanguageModelToolUse,
-        event_stream: &AgentResponseEventStream,
+        event_stream: &ThreadEventStream,
         cx: &mut Context<Self>,
     ) -> Option<Task<LanguageModelToolResult>> {
         cx.notify();
@@ -1083,11 +1347,85 @@ impl Thread {
         }
     }
 
+    pub fn title(&self) -> SharedString {
+        self.title.clone().unwrap_or("New Thread".into())
+    }
+
+    fn update_title(
+        &mut self,
+        event_stream: &ThreadEventStream,
+        cx: &mut Context<Self>,
+    ) -> Option<Task<Result<()>>> {
+        if self.title.is_some() {
+            log::debug!("Skipping title generation because we already have one.");
+            return None;
+        }
+
+        log::info!(
+            "Generating title with model: {:?}",
+            self.summarization_model.as_ref().map(|model| model.name())
+        );
+        let model = self.summarization_model.clone()?;
+        let event_stream = event_stream.clone();
+        let mut request = LanguageModelRequest {
+            intent: Some(CompletionIntent::ThreadSummarization),
+            temperature: AgentSettings::temperature_for_model(&model, cx),
+            ..Default::default()
+        };
+
+        for message in &self.messages {
+            request.messages.extend(message.to_request());
+        }
+
+        request.messages.push(LanguageModelRequestMessage {
+            role: Role::User,
+            content: vec![SUMMARIZE_THREAD_PROMPT.into()],
+            cache: false,
+        });
+        Some(cx.spawn(async move |this, cx| {
+            let mut title = String::new();
+            let mut messages = model.stream_completion(request, cx).await?;
+            while let Some(event) = messages.next().await {
+                let event = event?;
+                let text = match event {
+                    LanguageModelCompletionEvent::Text(text) => text,
+                    LanguageModelCompletionEvent::StatusUpdate(
+                        CompletionRequestStatus::UsageUpdated { .. },
+                    ) => {
+                        // this.update(cx, |thread, cx| {
+                        //     thread.update_model_request_usage(amount as u32, limit, cx);
+                        // })?;
+                        // TODO: handle usage update
+                        continue;
+                    }
+                    _ => continue,
+                };
+
+                let mut lines = text.lines();
+                title.extend(lines.next());
+
+                // Stop if the LLM generated multiple lines.
+                if lines.next().is_some() {
+                    break;
+                }
+            }
+
+            log::info!("Setting title: {}", title);
+
+            this.update(cx, |this, cx| {
+                let title = SharedString::from(title);
+                event_stream.send_title_update(title.clone());
+                this.title = Some(title);
+                cx.notify();
+            })
+        }))
+    }
+
     fn pending_message(&mut self) -> &mut AgentMessage {
         self.pending_message.get_or_insert_default()
     }
 
-    fn flush_pending_message(&mut self) {
+    fn flush_pending_message(&mut self, cx: &mut Context<Self>) {
         let Some(mut message) = self.pending_message.take() else {
             return;
         };
@@ -1104,9 +1442,7 @@ impl Thread {
                         tool_use_id: tool_use.id.clone(),
                         tool_name: tool_use.name.clone(),
                         is_error: true,
-                        content: LanguageModelToolResultContent::Text(
-                            "Tool canceled by user".into(),
-                        ),
+                        content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()),
                         output: None,
                     },
                 );
@@ -1114,6 +1450,8 @@ impl Thread {
         }
 
         self.messages.push(Message::Agent(message));
+        self.updated_at = Utc::now();
+        cx.notify()
     }
 
     pub(crate) fn build_completion_request(
@@ -1205,15 +1543,7 @@ impl Thread {
         );
         let mut messages = vec![self.build_system_message(cx)];
         for message in &self.messages {
-            match message {
-                Message::User(message) => messages.push(message.to_request()),
-                Message::Agent(message) => messages.extend(message.to_request()),
-                Message::Resume => messages.push(LanguageModelRequestMessage {
-                    role: Role::User,
-                    content: vec!["Continue where you left off".into()],
-                    cache: false,
-                }),
-            }
+            messages.extend(message.to_request());
         }
 
         if let Some(message) = self.pending_message.as_ref() {
@@ -1367,7 +1697,7 @@ struct RunningTurn {
     _task: Task<()>,
     /// The current event stream for the running turn. Used to report a final
     /// cancellation event if we cancel the turn.
-    event_stream: AgentResponseEventStream,
+    event_stream: ThreadEventStream,
 }
 
 impl RunningTurn {
@@ -1420,6 +1750,17 @@ where
         cx: &mut App,
     ) -> Task<Result<Self::Output>>;
 
+    /// Emits events for a previous execution of the tool.
+    fn replay(
+        &self,
+        _input: Self::Input,
+        _output: Self::Output,
+        _event_stream: ToolCallEventStream,
+        _cx: &mut App,
+    ) -> Result<()> {
+        Ok(())
+    }
+
     fn erase(self) -> Arc<dyn AnyAgentTool> {
         Arc::new(Erased(Arc::new(self)))
     }
@@ -1447,6 +1788,13 @@ pub trait AnyAgentTool {
         event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<AgentToolOutput>>;
+    fn replay(
+        &self,
+        input: serde_json::Value,
+        output: serde_json::Value,
+        event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Result<()>;
 }
 
 impl<T> AnyAgentTool for Erased<Arc<T>>
@@ -1498,21 +1846,45 @@ where
             })
         })
     }
+
+    fn replay(
+        &self,
+        input: serde_json::Value,
+        output: serde_json::Value,
+        event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Result<()> {
+        let input = serde_json::from_value(input)?;
+        let output = serde_json::from_value(output)?;
+        self.0.replay(input, output, event_stream, cx)
+    }
 }
 
 #[derive(Clone)]
-struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
+struct ThreadEventStream(mpsc::UnboundedSender<Result<ThreadEvent>>);
+
+impl ThreadEventStream {
+    fn send_title_update(&self, text: SharedString) {
+        self.0
+            .unbounded_send(Ok(ThreadEvent::TitleUpdate(text)))
+            .ok();
+    }
+
+    fn send_user_message(&self, message: &UserMessage) {
+        self.0
+            .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone())))
+            .ok();
+    }
 
-impl AgentResponseEventStream {
     fn send_text(&self, text: &str) {
         self.0
-            .unbounded_send(Ok(AgentResponseEvent::Text(text.to_string())))
+            .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string())))
             .ok();
     }
 
     fn send_thinking(&self, text: &str) {
         self.0
-            .unbounded_send(Ok(AgentResponseEvent::Thinking(text.to_string())))
+            .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string())))
             .ok();
     }
 
@@ -1524,7 +1896,7 @@ impl AgentResponseEventStream {
         input: serde_json::Value,
     ) {
         self.0
-            .unbounded_send(Ok(AgentResponseEvent::ToolCall(Self::initial_tool_call(
+            .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call(
                 id,
                 title.to_string(),
                 kind,
@@ -1557,7 +1929,7 @@ impl AgentResponseEventStream {
         fields: acp::ToolCallUpdateFields,
     ) {
         self.0
-            .unbounded_send(Ok(AgentResponseEvent::ToolCallUpdate(
+            .unbounded_send(Ok(ThreadEvent::ToolCallUpdate(
                 acp::ToolCallUpdate {
                     id: acp::ToolCallId(tool_use_id.to_string().into()),
                     fields,
@@ -1568,26 +1940,24 @@ impl AgentResponseEventStream {
     }
 
     fn send_retry(&self, status: acp_thread::RetryStatus) {
-        self.0
-            .unbounded_send(Ok(AgentResponseEvent::Retry(status)))
-            .ok();
+        self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
     }
 
     fn send_stop(&self, reason: StopReason) {
         match reason {
             StopReason::EndTurn => {
                 self.0
-                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::EndTurn)))
+                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
                     .ok();
             }
             StopReason::MaxTokens => {
                 self.0
-                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::MaxTokens)))
+                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
                     .ok();
             }
             StopReason::Refusal => {
                 self.0
-                    .unbounded_send(Ok(AgentResponseEvent::Stop(acp::StopReason::Refusal)))
+                    .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
                     .ok();
             }
             StopReason::ToolUse => {}

crates/agent2/src/tools/context_server_registry.rs 🔗

@@ -228,4 +228,14 @@ impl AnyAgentTool for ContextServerTool {
             })
         })
     }
+
+    fn replay(
+        &self,
+        _input: serde_json::Value,
+        _output: serde_json::Value,
+        _event_stream: ToolCallEventStream,
+        _cx: &mut App,
+    ) -> Result<()> {
+        Ok(())
+    }
 }

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -5,10 +5,10 @@ use anyhow::{Context as _, Result, anyhow};
 use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
 use cloud_llm_client::CompletionIntent;
 use collections::HashSet;
-use gpui::{App, AppContext, AsyncApp, Entity, Task};
+use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
 use indoc::formatdoc;
-use language::ToPoint;
 use language::language_settings::{self, FormatOnSave};
+use language::{LanguageRegistry, ToPoint};
 use language_model::LanguageModelToolResultContent;
 use paths;
 use project::lsp_store::{FormatTrigger, LspFormatTarget};
@@ -98,11 +98,13 @@ pub enum EditFileMode {
 
 #[derive(Debug, Serialize, Deserialize)]
 pub struct EditFileToolOutput {
+    #[serde(alias = "original_path")]
     input_path: PathBuf,
-    project_path: PathBuf,
     new_text: String,
     old_text: Arc<String>,
+    #[serde(default)]
     diff: String,
+    #[serde(alias = "raw_output")]
     edit_agent_output: EditAgentOutput,
 }
 
@@ -122,12 +124,16 @@ impl From<EditFileToolOutput> for LanguageModelToolResultContent {
 }
 
 pub struct EditFileTool {
-    thread: Entity<Thread>,
+    thread: WeakEntity<Thread>,
+    language_registry: Arc<LanguageRegistry>,
 }
 
 impl EditFileTool {
-    pub fn new(thread: Entity<Thread>) -> Self {
-        Self { thread }
+    pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
+        Self {
+            thread,
+            language_registry,
+        }
     }
 
     fn authorize(
@@ -167,8 +173,11 @@ impl EditFileTool {
 
         // Check if path is inside the global config directory
         // First check if it's already inside project - if not, try to canonicalize
-        let thread = self.thread.read(cx);
-        let project_path = thread.project().read(cx).find_project_path(&input.path, cx);
+        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
+            thread.project().read(cx).find_project_path(&input.path, cx)
+        }) else {
+            return Task::ready(Err(anyhow!("thread was dropped")));
+        };
 
         // If the path is inside the project, and it's not one of the above edge cases,
         // then no confirmation is necessary. Otherwise, confirmation is necessary.
@@ -221,7 +230,12 @@ impl AgentTool for EditFileTool {
         event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<Self::Output>> {
-        let project = self.thread.read(cx).project().clone();
+        let Ok(project) = self
+            .thread
+            .read_with(cx, |thread, _cx| thread.project().clone())
+        else {
+            return Task::ready(Err(anyhow!("thread was dropped")));
+        };
         let project_path = match resolve_path(&input, project.clone(), cx) {
             Ok(path) => path,
             Err(err) => return Task::ready(Err(anyhow!(err))),
@@ -237,23 +251,17 @@ impl AgentTool for EditFileTool {
             });
         }
 
-        let Some(request) = self.thread.update(cx, |thread, cx| {
-            thread
-                .build_completion_request(CompletionIntent::ToolResults, cx)
-                .ok()
-        }) else {
-            return Task::ready(Err(anyhow!("Failed to build completion request")));
-        };
-        let thread = self.thread.read(cx);
-        let Some(model) = thread.model().cloned() else {
-            return Task::ready(Err(anyhow!("No language model configured")));
-        };
-        let action_log = thread.action_log().clone();
-
         let authorize = self.authorize(&input, &event_stream, cx);
         cx.spawn(async move |cx: &mut AsyncApp| {
             authorize.await?;
 
+            let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
+                let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
+                (request, thread.model().cloned(), thread.action_log().clone())
+            })?;
+            let request = request?;
+            let model = model.context("No language model configured")?;
+
             let edit_format = EditFormat::from_model(model.clone())?;
             let edit_agent = EditAgent::new(
                 model,
@@ -419,7 +427,6 @@ impl AgentTool for EditFileTool {
 
             Ok(EditFileToolOutput {
                 input_path: input.path,
-                project_path: project_path.path.to_path_buf(),
                 new_text: new_text.clone(),
                 old_text,
                 diff: unified_diff,
@@ -427,6 +434,25 @@ impl AgentTool for EditFileTool {
             })
         })
     }
+
+    fn replay(
+        &self,
+        _input: Self::Input,
+        output: Self::Output,
+        event_stream: ToolCallEventStream,
+        cx: &mut App,
+    ) -> Result<()> {
+        event_stream.update_diff(cx.new(|cx| {
+            Diff::finalized(
+                output.input_path,
+                Some(output.old_text.to_string()),
+                output.new_text,
+                self.language_registry.clone(),
+                cx,
+            )
+        }));
+        Ok(())
+    }
 }
 
 /// Validate that the file path is valid, meaning:
@@ -515,6 +541,7 @@ mod tests {
         let fs = project::FakeFs::new(cx.executor());
         fs.insert_tree("/root", json!({})).await;
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -527,6 +554,7 @@ mod tests {
                 action_log,
                 Templates::new(),
                 Some(model),
+                None,
                 cx,
             )
         });
@@ -537,7 +565,11 @@ mod tests {
                     path: "root/nonexistent_file.txt".into(),
                     mode: EditFileMode::Edit,
                 };
-                Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+                Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
+                    input,
+                    ToolCallEventStream::test().0,
+                    cx,
+                )
             })
             .await;
         assert_eq!(
@@ -724,6 +756,7 @@ mod tests {
                 action_log.clone(),
                 Templates::new(),
                 Some(model.clone()),
+                None,
                 cx,
             )
         });
@@ -750,9 +783,10 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool {
-                    thread: thread.clone(),
-                })
+                Arc::new(EditFileTool::new(
+                    thread.downgrade(),
+                    language_registry.clone(),
+                ))
                 .run(input, ToolCallEventStream::test().0, cx)
             });
 
@@ -806,7 +840,11 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool { thread }).run(input, ToolCallEventStream::test().0, cx)
+                Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
+                    input,
+                    ToolCallEventStream::test().0,
+                    cx,
+                )
             });
 
             // Stream the unformatted content
@@ -850,6 +888,7 @@ mod tests {
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let model = Arc::new(FakeLanguageModel::default());
         let thread = cx.new(|cx| {
@@ -860,6 +899,7 @@ mod tests {
                 action_log.clone(),
                 Templates::new(),
                 Some(model.clone()),
+                None,
                 cx,
             )
         });
@@ -887,9 +927,10 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool {
-                    thread: thread.clone(),
-                })
+                Arc::new(EditFileTool::new(
+                    thread.downgrade(),
+                    language_registry.clone(),
+                ))
                 .run(input, ToolCallEventStream::test().0, cx)
             });
 
@@ -938,10 +979,11 @@ mod tests {
                     path: "root/src/main.rs".into(),
                     mode: EditFileMode::Overwrite,
                 };
-                Arc::new(EditFileTool {
-                    thread: thread.clone(),
-                })
-                .run(input, ToolCallEventStream::test().0, cx)
+                Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
+                    input,
+                    ToolCallEventStream::test().0,
+                    cx,
+                )
             });
 
             // Stream the content with trailing whitespace
@@ -976,6 +1018,7 @@ mod tests {
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let model = Arc::new(FakeLanguageModel::default());
         let thread = cx.new(|cx| {
@@ -986,10 +1029,11 @@ mod tests {
                 action_log.clone(),
                 Templates::new(),
                 Some(model.clone()),
+                None,
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
         fs.insert_tree("/root", json!({})).await;
 
         // Test 1: Path with .zed component should require confirmation
@@ -1111,6 +1155,7 @@ mod tests {
         let fs = project::FakeFs::new(cx.executor());
         fs.insert_tree("/project", json!({})).await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
@@ -1123,10 +1168,11 @@ mod tests {
                 action_log.clone(),
                 Templates::new(),
                 Some(model.clone()),
+                None,
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test global config paths - these should require confirmation if they exist and are outside the project
         let test_cases = vec![
@@ -1220,7 +1266,7 @@ mod tests {
             cx,
         )
         .await;
-
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1233,10 +1279,11 @@ mod tests {
                 action_log.clone(),
                 Templates::new(),
                 Some(model.clone()),
+                None,
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test files in different worktrees
         let test_cases = vec![
@@ -1302,6 +1349,7 @@ mod tests {
         )
         .await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1314,10 +1362,11 @@ mod tests {
                 action_log.clone(),
                 Templates::new(),
                 Some(model.clone()),
+                None,
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test edge cases
         let test_cases = vec![
@@ -1386,6 +1435,7 @@ mod tests {
         )
         .await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1398,10 +1448,11 @@ mod tests {
                 action_log.clone(),
                 Templates::new(),
                 Some(model.clone()),
+                None,
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test different EditFileMode values
         let modes = vec![
@@ -1467,6 +1518,7 @@ mod tests {
         init_test(cx);
         let fs = project::FakeFs::new(cx.executor());
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let context_server_registry =
             cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -1479,10 +1531,11 @@ mod tests {
                 action_log.clone(),
                 Templates::new(),
                 Some(model.clone()),
+                None,
                 cx,
             )
         });
-        let tool = Arc::new(EditFileTool { thread });
+        let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         assert_eq!(
             tool.initial_title(Err(json!({

crates/agent2/src/tools/terminal_tool.rs 🔗

@@ -319,7 +319,7 @@ mod tests {
     use theme::ThemeSettings;
     use util::test::TempTree;
 
-    use crate::AgentResponseEvent;
+    use crate::ThreadEvent;
 
     use super::*;
 
@@ -396,7 +396,7 @@ mod tests {
             });
             cx.run_until_parked();
             let event = stream_rx.try_next();
-            if let Ok(Some(Ok(AgentResponseEvent::ToolCallAuthorization(auth)))) = event {
+            if let Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(auth)))) = event {
                 auth.response.send(auth.options[0].id.clone()).unwrap();
             }
 

crates/agent2/src/tools/web_search_tool.rs 🔗

@@ -80,33 +80,48 @@ impl AgentTool for WebSearchTool {
                 }
             };
 
-            let result_text = if response.results.len() == 1 {
-                "1 result".to_string()
-            } else {
-                format!("{} results", response.results.len())
-            };
-            event_stream.update_fields(acp::ToolCallUpdateFields {
-                title: Some(format!("Searched the web: {result_text}")),
-                content: Some(
-                    response
-                        .results
-                        .iter()
-                        .map(|result| acp::ToolCallContent::Content {
-                            content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
-                                name: result.title.clone(),
-                                uri: result.url.clone(),
-                                title: Some(result.title.clone()),
-                                description: Some(result.text.clone()),
-                                mime_type: None,
-                                annotations: None,
-                                size: None,
-                            }),
-                        })
-                        .collect(),
-                ),
-                ..Default::default()
-            });
+            emit_update(&response, &event_stream);
             Ok(WebSearchToolOutput(response))
         })
     }
+
+    fn replay(
+        &self,
+        _input: Self::Input,
+        output: Self::Output,
+        event_stream: ToolCallEventStream,
+        _cx: &mut App,
+    ) -> Result<()> {
+        emit_update(&output.0, &event_stream);
+        Ok(())
+    }
+}
+
+fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
+    let result_text = if response.results.len() == 1 {
+        "1 result".to_string()
+    } else {
+        format!("{} results", response.results.len())
+    };
+    event_stream.update_fields(acp::ToolCallUpdateFields {
+        title: Some(format!("Searched the web: {result_text}")),
+        content: Some(
+            response
+                .results
+                .iter()
+                .map(|result| acp::ToolCallContent::Content {
+                    content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
+                        name: result.title.clone(),
+                        uri: result.url.clone(),
+                        title: Some(result.title.clone()),
+                        description: Some(result.text.clone()),
+                        mime_type: None,
+                        annotations: None,
+                        size: None,
+                    }),
+                })
+                .collect(),
+        ),
+        ..Default::default()
+    });
 }

crates/agent_servers/Cargo.toml 🔗

@@ -18,6 +18,7 @@ doctest = false
 
 [dependencies]
 acp_thread.workspace = true
+action_log.workspace = true
 agent-client-protocol.workspace = true
 agent_settings.workspace = true
 agentic-coding-protocol.workspace = true

crates/agent_servers/src/acp/v0.rs 🔗

@@ -1,4 +1,5 @@
 // Translates old acp agents into the new schema
+use action_log::ActionLog;
 use agent_client_protocol as acp;
 use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
 use anyhow::{Context as _, Result, anyhow};
@@ -443,7 +444,8 @@ impl AgentConnection for AcpConnection {
             cx.update(|cx| {
                 let thread = cx.new(|cx| {
                     let session_id = acp::SessionId("acp-old-no-id".into());
-                    AcpThread::new(self.name, self.clone(), project, session_id, cx)
+                    let action_log = cx.new(|_| ActionLog::new(project.clone()));
+                    AcpThread::new(self.name, self.clone(), project, action_log, session_id)
                 });
                 current_thread.replace(thread.downgrade());
                 thread

crates/agent_servers/src/acp/v1.rs 🔗

@@ -1,3 +1,4 @@
+use action_log::ActionLog;
 use agent_client_protocol::{self as acp, Agent as _};
 use anyhow::anyhow;
 use collections::HashMap;
@@ -153,14 +154,14 @@ impl AgentConnection for AcpConnection {
                 })?;
 
             let session_id = response.session_id;
-
-            let thread = cx.new(|cx| {
+            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
+            let thread = cx.new(|_cx| {
                 AcpThread::new(
                     self.server_name,
                     self.clone(),
                     project,
+                    action_log,
                     session_id.clone(),
-                    cx,
                 )
             })?;
 

crates/agent_servers/src/claude.rs 🔗

@@ -1,6 +1,7 @@
 mod mcp_server;
 pub mod tools;
 
+use action_log::ActionLog;
 use collections::HashMap;
 use context_server::listener::McpServerTool;
 use language_models::provider::anthropic::AnthropicLanguageModelProvider;
@@ -215,8 +216,15 @@ impl AgentConnection for ClaudeAgentConnection {
                 }
             });
 
-            let thread = cx.new(|cx| {
-                AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
+            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
+            let thread = cx.new(|_cx| {
+                AcpThread::new(
+                    "Claude Code",
+                    self.clone(),
+                    project,
+                    action_log,
+                    session_id.clone(),
+                )
             })?;
 
             thread_tx.send(thread.downgrade())?;

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -303,8 +303,13 @@ impl AcpThreadView {
                         let action_log_subscription =
                             cx.observe(&action_log, |_, _, cx| cx.notify());
 
-                        this.list_state
-                            .splice(0..0, thread.read(cx).entries().len());
+                        let count = thread.read(cx).entries().len();
+                        this.list_state.splice(0..0, count);
+                        this.entry_view_state.update(cx, |view_state, cx| {
+                            for ix in 0..count {
+                                view_state.sync_entry(ix, &thread, window, cx);
+                            }
+                        });
 
                         AgentDiff::set_active_thread(&workspace, thread.clone(), window, cx);
 
@@ -808,6 +813,7 @@ impl AcpThreadView {
                 self.thread_retry_status.take();
                 self.thread_state = ThreadState::ServerExited { status: *status };
             }
+            AcpThreadEvent::TitleUpdated => {}
         }
         cx.notify();
     }
@@ -2816,12 +2822,15 @@ impl AcpThreadView {
             return;
         };
 
-        thread.update(cx, |thread, _cx| {
+        thread.update(cx, |thread, cx| {
             let current_mode = thread.completion_mode();
-            thread.set_completion_mode(match current_mode {
-                CompletionMode::Burn => CompletionMode::Normal,
-                CompletionMode::Normal => CompletionMode::Burn,
-            });
+            thread.set_completion_mode(
+                match current_mode {
+                    CompletionMode::Burn => CompletionMode::Normal,
+                    CompletionMode::Normal => CompletionMode::Burn,
+                },
+                cx,
+            );
         });
     }
 
@@ -3572,8 +3581,9 @@ impl AcpThreadView {
                                     ))
                                     .on_click({
                                         cx.listener(move |this, _, _window, cx| {
-                                            thread.update(cx, |thread, _cx| {
-                                                thread.set_completion_mode(CompletionMode::Burn);
+                                            thread.update(cx, |thread, cx| {
+                                                thread
+                                                    .set_completion_mode(CompletionMode::Burn, cx);
                                             });
                                             this.resume_chat(cx);
                                         })
@@ -4156,12 +4166,13 @@ pub(crate) mod tests {
             cx: &mut gpui::App,
         ) -> Task<gpui::Result<Entity<AcpThread>>> {
             Task::ready(Ok(cx.new(|cx| {
+                let action_log = cx.new(|_| ActionLog::new(project.clone()));
                 AcpThread::new(
                     "SaboteurAgentConnection",
                     self,
                     project,
+                    action_log,
                     SessionId("test".into()),
-                    cx,
                 )
             })))
         }

crates/agent_ui/src/agent_diff.rs 🔗

@@ -199,24 +199,21 @@ impl AgentDiffPane {
         let action_log = thread.action_log(cx).clone();
 
         let mut this = Self {
-            _subscriptions: [
-                Some(
-                    cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
-                        this.update_excerpts(window, cx)
-                    }),
-                ),
+            _subscriptions: vec![
+                cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
+                    this.update_excerpts(window, cx)
+                }),
                 match &thread {
-                    AgentDiffThread::Native(thread) => {
-                        Some(cx.subscribe(thread, |this, _thread, event, cx| {
-                            this.handle_thread_event(event, cx)
-                        }))
-                    }
-                    AgentDiffThread::AcpThread(_) => None,
+                    AgentDiffThread::Native(thread) => cx
+                        .subscribe(thread, |this, _thread, event, cx| {
+                            this.handle_native_thread_event(event, cx)
+                        }),
+                    AgentDiffThread::AcpThread(thread) => cx
+                        .subscribe(thread, |this, _thread, event, cx| {
+                            this.handle_acp_thread_event(event, cx)
+                        }),
                 },
-            ]
-            .into_iter()
-            .flatten()
-            .collect(),
+            ],
             title: SharedString::default(),
             multibuffer,
             editor,
@@ -324,13 +321,20 @@ impl AgentDiffPane {
         }
     }
 
-    fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
+    fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
         match event {
             ThreadEvent::SummaryGenerated => self.update_title(cx),
             _ => {}
         }
     }
 
+    fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
+        match event {
+            AcpThreadEvent::TitleUpdated => self.update_title(cx),
+            _ => {}
+        }
+    }
+
     pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
         if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
             self.editor.update(cx, |editor, cx| {
@@ -1523,7 +1527,8 @@ impl AgentDiff {
             AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::ServerExited(_) => {
                 self.update_reviewing_editors(workspace, window, cx);
             }
-            AcpThreadEvent::EntriesRemoved(_)
+            AcpThreadEvent::TitleUpdated
+            | AcpThreadEvent::EntriesRemoved(_)
             | AcpThreadEvent::ToolAuthorizationRequired
             | AcpThreadEvent::Retry(_) => {}
         }