Saving history with thread titles

Conrad Irwin created

Change summary

crates/acp_thread/src/acp_thread.rs       |   7 +
crates/agent2/src/agent.rs                |  60 +++++++--
crates/agent2/src/db.rs                   |   2 
crates/agent2/src/history_store.rs        |   3 
crates/agent2/src/tests/mod.rs            |   1 
crates/agent2/src/thread.rs               | 155 +++++++++++++++++++++++-
crates/agent2/src/tools/edit_file_tool.rs | 143 ++--------------------
crates/agent_ui/src/acp/thread_history.rs |  11 -
crates/agent_ui/src/acp/thread_view.rs    |   1 
crates/agent_ui/src/agent_diff.rs         |  41 +++--
10 files changed, 241 insertions(+), 183 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -691,6 +691,7 @@ pub struct AcpThread {
 
 pub enum AcpThreadEvent {
     NewEntry,
+    TitleUpdated,
     EntryUpdated(usize),
     EntriesRemoved(Range<usize>),
     ToolAuthorizationRequired,
@@ -934,6 +935,12 @@ impl AcpThread {
         cx.emit(AcpThreadEvent::NewEntry);
     }
 
+    pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
+        self.title = title;
+        cx.emit(AcpThreadEvent::TitleUpdated);
+        Ok(())
+    }
+
     pub fn update_tool_call(
         &mut self,
         update: impl Into<ToolCallUpdate>,

crates/agent2/src/agent.rs 🔗

@@ -255,6 +255,9 @@ impl NativeAgent {
                         this.sessions.remove(acp_thread.session_id());
                     }),
                     cx.observe(&thread, |this, thread, cx| {
+                        thread.update(cx, |thread, cx| {
+                            thread.generate_title_if_needed(cx);
+                        });
                         this.save_thread(thread.clone(), cx)
                     }),
                 ],
@@ -262,13 +265,14 @@ impl NativeAgent {
         );
     }
 
-    fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
-        let id = thread.read(cx).id().clone();
+    fn save_thread(&mut self, thread_handle: Entity<Thread>, cx: &mut Context<Self>) {
+        let thread = thread_handle.read(cx);
+        let id = thread.id().clone();
         let Some(session) = self.sessions.get_mut(&id) else {
             return;
         };
 
-        let thread = thread.downgrade();
+        let thread = thread_handle.downgrade();
         let thread_database = self.thread_database.clone();
         session.save_task = cx.spawn(async move |this, cx| {
             cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
@@ -507,7 +511,7 @@ impl NativeAgent {
 
     fn handle_models_updated_event(
         &mut self,
-        _registry: Entity<LanguageModelRegistry>,
+        registry: Entity<LanguageModelRegistry>,
         _event: &language_model::Event,
         cx: &mut Context<Self>,
     ) {
@@ -518,6 +522,11 @@ impl NativeAgent {
                 if let Some(model) = self.models.model_from_id(&model_id) {
                     thread.set_model(model.clone(), cx);
                 }
+                let summarization_model = registry
+                    .read(cx)
+                    .thread_summary_model()
+                    .map(|model| model.model.clone());
+                thread.set_summarization_model(summarization_model, cx);
             });
         }
     }
@@ -641,6 +650,10 @@ impl NativeAgentConnection {
                                     thread.update_tool_call(update, cx)
                                 })??;
                             }
+                            ThreadEvent::TitleUpdate(title) => {
+                                acp_thread
+                                    .update(cx, |thread, cx| thread.update_title(title, cx))??;
+                            }
                             ThreadEvent::Stop(stop_reason) => {
                                 log::debug!("Assistant message complete: {:?}", stop_reason);
                                 return Ok(acp::PromptResponse { stop_reason });
@@ -821,6 +834,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                             )
                         })?;
 
+                    let summarization_model = registry.thread_summary_model().map(|c| c.model);
+
                     let thread = cx.new(|cx| {
                         let mut thread = Thread::new(
                             session_id.clone(),
@@ -830,6 +845,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                             action_log.clone(),
                             agent.templates.clone(),
                             default_model,
+                            summarization_model,
                             cx,
                         );
                         Self::register_tools(&mut thread, project, action_log, cx);
@@ -894,7 +910,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
 
             // Create Thread
             let thread = agent.update(cx, |agent, cx| {
-                let configured_model = LanguageModelRegistry::global(cx)
+                let language_model_registry = LanguageModelRegistry::global(cx);
+                let configured_model = language_model_registry
                     .update(cx, |registry, cx| {
                         db_thread
                             .model
@@ -915,6 +932,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                     .model_from_id(&LanguageModels::model_id(&configured_model.model))
                     .context("no model by id")?;
 
+                let summarization_model = language_model_registry
+                    .read(cx)
+                    .thread_summary_model()
+                    .map(|c| c.model);
+
                 let thread = cx.new(|cx| {
                     let mut thread = Thread::from_db(
                         session_id,
@@ -925,6 +947,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                         action_log.clone(),
                         agent.templates.clone(),
                         model,
+                        summarization_model,
                         cx,
                     );
                     Self::register_tools(&mut thread, project, action_log, cx);
@@ -1047,12 +1070,13 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
 
 #[cfg(test)]
 mod tests {
-    use crate::{HistoryEntry, HistoryStore};
+    use crate::HistoryStore;
 
     use super::*;
     use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
     use fs::FakeFs;
     use gpui::TestAppContext;
+    use language_model::fake_provider::FakeLanguageModel;
     use serde_json::json;
     use settings::SettingsStore;
     use util::path;
@@ -1245,13 +1269,6 @@ mod tests {
         )
         .await
         .unwrap();
-        let model = cx.update(|cx| {
-            LanguageModelRegistry::global(cx)
-                .read(cx)
-                .default_model()
-                .unwrap()
-                .model
-        });
         let connection = NativeAgentConnection(agent.clone());
         let history_store = cx.new(|cx| {
             let mut store = HistoryStore::new(cx);
@@ -1268,6 +1285,16 @@ mod tests {
         let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
         let selector = connection.model_selector().unwrap();
 
+        let summarization_model: Arc<dyn LanguageModel> =
+            Arc::new(FakeLanguageModel::default()) as _;
+
+        agent.update(cx, |agent, cx| {
+            let thread = agent.sessions.get(&session_id).unwrap().thread.clone();
+            thread.update(cx, |thread, cx| {
+                thread.set_summarization_model(Some(summarization_model.clone()), cx);
+            })
+        });
+
         let model = cx
             .update(|cx| selector.selected_model(&session_id, cx))
             .await
@@ -1283,11 +1310,16 @@ mod tests {
         model.send_last_completion_stream_text_chunk("Hey");
         model.end_last_completion_stream();
         send.await.unwrap();
+
+        summarization_model
+            .as_fake()
+            .send_last_completion_stream_text_chunk("Saying Hello");
+        summarization_model.as_fake().end_last_completion_stream();
         cx.executor().advance_clock(SAVE_THREAD_DEBOUNCE);
 
         let history = history_store.update(cx, |store, cx| store.entries(cx));
         assert_eq!(history.len(), 1);
-        assert_eq!(history[0].title(), "Hi");
+        assert_eq!(history[0].title(), "Saying Hello");
     }
 
     fn init_test(cx: &mut TestAppContext) {

crates/agent2/src/db.rs 🔗

@@ -386,8 +386,6 @@ impl ThreadsDatabase {
 
 #[cfg(test)]
 mod tests {
-    use crate::NativeAgent;
-    use crate::Templates;
 
     use super::*;
     use agent::MessageSegment;

crates/agent2/src/history_store.rs 🔗

@@ -1,12 +1,11 @@
 use acp_thread::{AcpThreadMetadata, AgentConnection, AgentServerName};
 use agent_client_protocol as acp;
-use anyhow::{Context as _, Result};
 use assistant_context::SavedContextMetadata;
 use chrono::{DateTime, Utc};
 use collections::HashMap;
 use gpui::{SharedString, Task, prelude::*};
 use serde::{Deserialize, Serialize};
-use smol::stream::StreamExt;
+
 use std::{path::Path, sync::Arc, time::Duration};
 
 const MAX_RECENTLY_OPENED_ENTRIES: usize = 6;

crates/agent2/src/tests/mod.rs 🔗

@@ -1506,6 +1506,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
             action_log,
             templates,
             model.clone(),
+            None,
             cx,
         )
     });

crates/agent2/src/thread.rs 🔗

@@ -5,7 +5,7 @@ use acp_thread::{MentionUri, UserMessageId};
 use action_log::ActionLog;
 use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot};
 use agent_client_protocol as acp;
-use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
+use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_PROMPT};
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
 use chrono::{DateTime, Utc};
@@ -24,7 +24,7 @@ use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
     LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
     LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
-    LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage,
+    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role, StopReason, TokenUsage,
 };
 use project::{
     Project,
@@ -75,6 +75,18 @@ impl Message {
         }
     }
 
+    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
+        match self {
+            Message::User(message) => vec![message.to_request()],
+            Message::Agent(message) => message.to_request(),
+            Message::Resume => vec![LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Continue where you left off".into()],
+                cache: false,
+            }],
+        }
+    }
+
     pub fn to_markdown(&self) -> String {
         match self {
             Message::User(message) => message.to_markdown(),
@@ -82,6 +94,13 @@ impl Message {
             Message::Resume => "[resumed after tool use limit was reached]".into(),
         }
     }
+
+    pub fn role(&self) -> Role {
+        match self {
+            Message::User(_) | Message::Resume => Role::User,
+            Message::Agent(_) => Role::Assistant,
+        }
+    }
 }
 
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -426,6 +445,7 @@ pub enum ThreadEvent {
     ToolCall(acp::ToolCall),
     ToolCallUpdate(acp_thread::ToolCallUpdate),
     ToolCallAuthorization(ToolCallAuthorization),
+    TitleUpdate(SharedString),
     Stop(acp::StopReason),
 }
 
@@ -475,6 +495,7 @@ pub struct Thread {
     project_context: Rc<RefCell<ProjectContext>>,
     templates: Arc<Templates>,
     model: Arc<dyn LanguageModel>,
+    summarization_model: Option<Arc<dyn LanguageModel>>,
     project: Entity<Project>,
     action_log: Entity<ActionLog>,
 }
@@ -488,6 +509,7 @@ impl Thread {
         action_log: Entity<ActionLog>,
         templates: Arc<Templates>,
         model: Arc<dyn LanguageModel>,
+        summarization_model: Option<Arc<dyn LanguageModel>>,
         cx: &mut Context<Self>,
     ) -> Self {
         let profile_id = AgentSettings::get_global(cx).default_profile.clone();
@@ -516,11 +538,37 @@ impl Thread {
             project_context,
             templates,
             model,
+            summarization_model,
             project,
             action_log,
         }
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn test(
+        model: Arc<dyn LanguageModel>,
+        project: Entity<Project>,
+        action_log: Entity<ActionLog>,
+        cx: &mut Context<Self>,
+    ) -> Self {
+        use crate::generate_session_id;
+
+        let context_server_registry =
+            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+
+        Self::new(
+            generate_session_id(),
+            project,
+            Rc::default(),
+            context_server_registry,
+            action_log,
+            Templates::new(),
+            model,
+            None,
+            cx,
+        )
+    }
+
     pub fn id(&self) -> &acp::SessionId {
         &self.id
     }
@@ -534,6 +582,7 @@ impl Thread {
         action_log: Entity<ActionLog>,
         templates: Arc<Templates>,
         model: Arc<dyn LanguageModel>,
+        summarization_model: Option<Arc<dyn LanguageModel>>,
         cx: &mut Context<Self>,
     ) -> Self {
         let profile_id = db_thread
@@ -558,6 +607,7 @@ impl Thread {
             project_context,
             templates,
             model,
+            summarization_model,
             project,
             action_log,
             updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list)
@@ -807,6 +857,15 @@ impl Thread {
         cx.notify()
     }
 
+    pub fn set_summarization_model(
+        &mut self,
+        model: Option<Arc<dyn LanguageModel>>,
+        cx: &mut Context<Self>,
+    ) {
+        self.summarization_model = model;
+        cx.notify()
+    }
+
     pub fn completion_mode(&self) -> CompletionMode {
         self.completion_mode
     }
@@ -1018,6 +1077,86 @@ impl Thread {
         events_rx
     }
 
+    pub fn generate_title_if_needed(&mut self, cx: &mut Context<Self>) {
+        if !matches!(self.title, ThreadTitle::None) {
+            return;
+        }
+
+        // todo!() copy logic from agent1 re: tool calls, etc.?
+        if self.messages.len() < 2 {
+            return;
+        }
+
+        self.generate_title(cx);
+    }
+
+    fn generate_title(&mut self, cx: &mut Context<Self>) {
+        let Some(model) = self.summarization_model.clone() else {
+            println!("No thread summary model");
+            return;
+        };
+        let mut request = LanguageModelRequest {
+            intent: Some(CompletionIntent::ThreadSummarization),
+            temperature: AgentSettings::temperature_for_model(&model, cx),
+            ..Default::default()
+        };
+
+        for message in &self.messages {
+            request.messages.extend(message.to_request());
+        }
+
+        request.messages.push(LanguageModelRequestMessage {
+            role: Role::User,
+            content: vec![MessageContent::Text(SUMMARIZE_THREAD_PROMPT.into())],
+            cache: false,
+        });
+
+        let task = cx.spawn(async move |this, cx| {
+            let result = async {
+                let mut messages = model.stream_completion(request, &cx).await?;
+
+                let mut new_summary = String::new();
+                while let Some(event) = messages.next().await {
+                    let Ok(event) = event else {
+                        continue;
+                    };
+                    let text = match event {
+                        LanguageModelCompletionEvent::Text(text) => text,
+                        LanguageModelCompletionEvent::StatusUpdate(
+                            CompletionRequestStatus::UsageUpdated { .. },
+                        ) => {
+                            // this.update(cx, |thread, cx| {
+                            //     thread.update_model_request_usage(amount as u32, limit, cx);
+                            // })?;
+                            // todo!()? not sure if this is the right place to do this.
+                            continue;
+                        }
+                        _ => continue,
+                    };
+
+                    let mut lines = text.lines();
+                    new_summary.extend(lines.next());
+
+                    // Stop if the LLM generated multiple lines.
+                    if lines.next().is_some() {
+                        break;
+                    }
+                }
+
+                anyhow::Ok(new_summary.into())
+            }
+            .await;
+
+            this.update(cx, |this, cx| {
+                this.title = ThreadTitle::Done(result);
+                cx.notify();
+            })
+            .log_err();
+        });
+
+        self.title = ThreadTitle::Pending(task);
+    }
+
     pub fn build_system_message(&self) -> LanguageModelRequestMessage {
         log::debug!("Building system message");
         let prompt = SystemPromptTemplate {
@@ -1373,15 +1512,7 @@ impl Thread {
         );
         let mut messages = vec![self.build_system_message()];
         for message in &self.messages {
-            match message {
-                Message::User(message) => messages.push(message.to_request()),
-                Message::Agent(message) => messages.extend(message.to_request()),
-                Message::Resume => messages.push(LanguageModelRequestMessage {
-                    role: Role::User,
-                    content: vec!["Continue where you left off".into()],
-                    cache: false,
-                }),
-            }
+            messages.extend(message.to_request());
         }
 
         if let Some(message) = self.pending_message.as_ref() {
@@ -1924,7 +2055,7 @@ impl From<UserMessageContent> for acp::ContentBlock {
                 annotations: None,
                 uri: None,
             }),
-            UserMessageContent::Mention { uri, content } => {
+            UserMessageContent::Mention { .. } => {
                 todo!()
             }
         }

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -521,7 +521,6 @@ fn resolve_path(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::{ContextServerRegistry, Templates, generate_session_id};
     use action_log::ActionLog;
     use client::TelemetrySettings;
     use fs::Fs;
@@ -529,7 +528,6 @@ mod tests {
     use language_model::fake_provider::FakeLanguageModel;
     use serde_json::json;
     use settings::SettingsStore;
-    use std::rc::Rc;
     use util::path;
 
     #[gpui::test]
@@ -541,21 +539,8 @@ mod tests {
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
         let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            Thread::new(
-                generate_session_id(),
-                project,
-                Rc::default(),
-                context_server_registry,
-                action_log,
-                Templates::new(),
-                model,
-                cx,
-            )
-        });
+        let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
         let result = cx
             .update(|cx| {
                 let input = EditFileToolInput {
@@ -743,21 +728,8 @@ mod tests {
         });
 
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            Thread::new(
-                generate_session_id(),
-                project,
-                Rc::default(),
-                context_server_registry,
-                action_log.clone(),
-                Templates::new(),
-                model.clone(),
-                cx,
-            )
-        });
+        let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log.clone(), cx));
 
         // First, test with format_on_save enabled
         cx.update(|cx| {
@@ -885,22 +857,9 @@ mod tests {
 
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
         let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            Thread::new(
-                generate_session_id(),
-                project,
-                Rc::default(),
-                context_server_registry,
-                action_log.clone(),
-                Templates::new(),
-                model.clone(),
-                cx,
-            )
-        });
+        let thread = cx.new(|cx| Thread::test(model.clone(), project, action_log, cx));
 
         // First, test with remove_trailing_whitespace_on_save enabled
         cx.update(|cx| {
@@ -1015,22 +974,10 @@ mod tests {
         let fs = project::FakeFs::new(cx.executor());
         let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
         let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            Thread::new(
-                generate_session_id(),
-                project,
-                Rc::default(),
-                context_server_registry,
-                action_log.clone(),
-                Templates::new(),
-                model.clone(),
-                cx,
-            )
-        });
+        let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
+
         let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
         fs.insert_tree("/root", json!({})).await;
 
@@ -1154,22 +1101,10 @@ mod tests {
         fs.insert_tree("/project", json!({})).await;
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
         let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            Thread::new(
-                generate_session_id(),
-                project,
-                Rc::default(),
-                context_server_registry,
-                action_log.clone(),
-                Templates::new(),
-                model.clone(),
-                cx,
-            )
-        });
+        let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
+
         let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test global config paths - these should require confirmation if they exist and are outside the project
@@ -1266,21 +1201,9 @@ mod tests {
         .await;
         let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            Thread::new(
-                generate_session_id(),
-                project.clone(),
-                Rc::default(),
-                context_server_registry.clone(),
-                action_log.clone(),
-                Templates::new(),
-                model.clone(),
-                cx,
-            )
-        });
+        let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
+
         let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test files in different worktrees
@@ -1349,21 +1272,9 @@ mod tests {
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
         let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            Thread::new(
-                generate_session_id(),
-                project.clone(),
-                Rc::default(),
-                context_server_registry.clone(),
-                action_log.clone(),
-                Templates::new(),
-                model.clone(),
-                cx,
-            )
-        });
+        let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
+
         let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test edge cases
@@ -1435,21 +1346,9 @@ mod tests {
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
         let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            Thread::new(
-                generate_session_id(),
-                project.clone(),
-                Rc::default(),
-                context_server_registry.clone(),
-                action_log.clone(),
-                Templates::new(),
-                model.clone(),
-                cx,
-            )
-        });
+        let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
+
         let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         // Test different EditFileMode values
@@ -1518,21 +1417,9 @@ mod tests {
         let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
         let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
         let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
         let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            Thread::new(
-                generate_session_id(),
-                project.clone(),
-                Rc::default(),
-                context_server_registry,
-                action_log.clone(),
-                Templates::new(),
-                model.clone(),
-                cx,
-            )
-        });
+        let thread = cx.new(|cx| Thread::test(model, project, action_log, cx));
+
         let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
 
         assert_eq!(

crates/agent_ui/src/acp/thread_history.rs 🔗

@@ -1,15 +1,12 @@
-use crate::{AgentPanel, RemoveSelectedThread};
+use crate::RemoveSelectedThread;
 use agent_servers::AgentServer;
-use agent2::{
-    NativeAgentServer,
-    history_store::{HistoryEntry, HistoryStore},
-};
+use agent2::{HistoryEntry, HistoryStore, NativeAgentServer};
 use chrono::{Datelike as _, Local, NaiveDate, TimeDelta};
 use editor::{Editor, EditorEvent};
 use fuzzy::{StringMatch, StringMatchCandidate};
 use gpui::{
     App, Empty, Entity, EventEmitter, FocusHandle, Focusable, ScrollStrategy, Stateful, Task,
-    UniformListScrollHandle, WeakEntity, Window, uniform_list,
+    UniformListScrollHandle, Window, uniform_list,
 };
 use project::Project;
 use std::{fmt::Display, ops::Range, sync::Arc};
@@ -72,7 +69,7 @@ impl AcpThreadHistory {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
-        let history_store = cx.new(|cx| agent2::history_store::HistoryStore::new(cx));
+        let history_store = cx.new(|cx| agent2::HistoryStore::new(cx));
 
         let agent = NativeAgentServer::new(project.read(cx).fs().clone());
 

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -687,6 +687,7 @@ impl AcpThreadView {
             AcpThreadEvent::ServerExited(status) => {
                 self.thread_state = ThreadState::ServerExited { status: *status };
             }
+            AcpThreadEvent::TitleUpdated => {}
         }
         cx.notify();
     }

crates/agent_ui/src/agent_diff.rs 🔗

@@ -199,24 +199,21 @@ impl AgentDiffPane {
         let action_log = thread.action_log(cx).clone();
 
         let mut this = Self {
-            _subscriptions: [
-                Some(
-                    cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
-                        this.update_excerpts(window, cx)
-                    }),
-                ),
+            _subscriptions: vec![
+                cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
+                    this.update_excerpts(window, cx)
+                }),
                 match &thread {
-                    AgentDiffThread::Native(thread) => {
-                        Some(cx.subscribe(&thread, |this, _thread, event, cx| {
-                            this.handle_thread_event(event, cx)
-                        }))
-                    }
-                    AgentDiffThread::AcpThread(_) => None,
+                    AgentDiffThread::Native(thread) => cx
+                        .subscribe(&thread, |this, _thread, event, cx| {
+                            this.handle_native_thread_event(event, cx)
+                        }),
+                    AgentDiffThread::AcpThread(thread) => cx
+                        .subscribe(&thread, |this, _thread, event, cx| {
+                            this.handle_acp_thread_event(event, cx)
+                        }),
                 },
-            ]
-            .into_iter()
-            .flatten()
-            .collect(),
+            ],
             title: SharedString::default(),
             multibuffer,
             editor,
@@ -324,13 +321,20 @@ impl AgentDiffPane {
         }
     }
 
-    fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
+    fn handle_native_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
         match event {
             ThreadEvent::SummaryGenerated => self.update_title(cx),
             _ => {}
         }
     }
 
+    fn handle_acp_thread_event(&mut self, event: &AcpThreadEvent, cx: &mut Context<Self>) {
+        match event {
+            AcpThreadEvent::TitleUpdated => self.update_title(cx),
+            _ => {}
+        }
+    }
+
     pub fn move_to_path(&self, path_key: PathKey, window: &mut Window, cx: &mut App) {
         if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
             self.editor.update(cx, |editor, cx| {
@@ -1521,7 +1525,8 @@ impl AgentDiff {
                     self.update_reviewing_editors(workspace, window, cx);
                 }
             }
-            AcpThreadEvent::EntriesRemoved(_)
+            AcpThreadEvent::TitleUpdated
+            | AcpThreadEvent::EntriesRemoved(_)
             | AcpThreadEvent::Stopped
             | AcpThreadEvent::ToolAuthorizationRequired
             | AcpThreadEvent::Error