From d83210d978b7bb1fe11925c0cc772d405f79e003 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 18 Aug 2025 17:30:12 +0200 Subject: [PATCH] WIP Co-authored-by: Conrad Irwin --- Cargo.lock | 1 + crates/acp_thread/src/connection.rs | 5 +- crates/agent2/Cargo.toml | 2 + crates/agent2/src/agent.rs | 151 ++++++++------ crates/agent2/src/agent2.rs | 6 + crates/agent2/src/tests/mod.rs | 5 +- crates/agent2/src/thread.rs | 242 ++++++++++++++++++---- crates/agent2/src/tools/edit_file_tool.rs | 11 +- 8 files changed, 319 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5479fddab3011b48d7bc90ad6246572c18e23e2c..41411110730160891996879ce842b66d909fd1bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -211,6 +211,7 @@ dependencies = [ "env_logger 0.11.8", "fs", "futures 0.3.31", + "git", "gpui", "gpui_tokio", "handlebars 4.5.0", diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index fe6d3169bd9a2d3c4008e59ee4e5aed7c12bc82a..398222a831a199e264f4ed57a433160e8421ff13 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -27,7 +27,10 @@ pub trait AgentConnection { cx: &mut App, ) -> Task>>; - fn list_threads(&self, _cx: &mut App) -> Option>> { + fn list_threads( + &self, + _cx: &mut App, + ) -> Option>>> { return None; } diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml index ed487e2e22e524ada03e075f8135d9be6450ecc1..88da8759306a4b427670824c68893694900a333f 100644 --- a/crates/agent2/Cargo.toml +++ b/crates/agent2/Cargo.toml @@ -27,6 +27,7 @@ collections.workspace = true context_server.workspace = true fs.workspace = true futures.workspace = true +git.workspace = true gpui.workspace = true handlebars = { workspace = true, features = ["rust-embed"] } html_to_markdown.workspace = true @@ -72,6 +73,7 @@ context_server = { workspace = true, "features" = ["test-support"] } editor = { workspace = true, "features" = ["test-support"] } env_logger.workspace = true fs = { workspace = true, "features" = ["test-support"] } +git = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } gpui_tokio.workspace = true language = { workspace = true, "features" = ["test-support"] } diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 398054739e48c6ef45a1fa6ac21b461e1758458a..403a59e51b8754d1cb28564db1a9bd1b4aac2428 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -5,7 +5,7 @@ use crate::{ OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates, }; -use crate::{DbThread, ThreadId, ThreadsDatabase}; +use crate::{DbThread, ThreadId, ThreadsDatabase, generate_session_id}; use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector}; use agent_client_protocol as acp; use agent_settings::AgentSettings; @@ -44,6 +44,8 @@ const RULES_FILE_NAMES: [&'static str; 9] = [ "GEMINI.md", ]; +const SAVE_THREAD_DEBOUNCE: Duration = Duration::from_millis(500); + pub struct RulesLoadingError { pub message: SharedString, } @@ -54,7 +56,8 @@ struct Session { thread: Entity, /// The ACP thread that handles protocol communication acp_thread: WeakEntity, - _subscription: Subscription, + save_task: Task>, + _subscriptions: Vec, } pub struct LanguageModels { @@ -169,8 +172,9 @@ pub struct NativeAgent { models: LanguageModels, project: Entity, prompt_store: Option>, - thread_database: Shared, Arc>>>, - history_listeners: Vec>>, + thread_database: Arc, + history: watch::Sender>>, + load_history: Task>, fs: Arc, _subscriptions: Vec, } @@ -189,6 +193,11 @@ impl NativeAgent { .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))? .await; + let thread_database = cx + .update(|cx| ThreadsDatabase::connect(cx))? + .await + .map_err(|e| anyhow!(e))?; + cx.new(|cx| { let mut subscriptions = vec![ cx.subscribe(&project, Self::handle_project_event), @@ -203,7 +212,7 @@ impl NativeAgent { let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = watch::channel(()); - Self { + let this = Self { sessions: HashMap::new(), project_context: Rc::new(RefCell::new(project_context)), project_context_needs_refresh: project_context_needs_refresh_tx, @@ -213,18 +222,85 @@ impl NativeAgent { context_server_registry: cx.new(|cx| { ContextServerRegistry::new(project.read(cx).context_server_store(), cx) }), - thread_database: ThreadsDatabase::connect(cx), + thread_database, templates, models: LanguageModels::new(cx), project, prompt_store, fs, - history_listeners: Vec::new(), + history: watch::channel(None).0, + load_history: Task::ready(Ok(())), _subscriptions: subscriptions, - } + }; + this.reload_history(cx); + this }) } + pub fn insert_session( + &mut self, + thread: Entity, + acp_thread: Entity, + cx: &mut Context, + ) { + let id = thread.read(cx).id().clone(); + self.sessions.insert( + id, + Session { + thread: thread.clone(), + acp_thread: acp_thread.downgrade(), + save_task: Task::ready(()), + _subscriptions: vec![ + cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }), + cx.observe(&thread, |this, thread, cx| { + this.save_thread(thread.clone(), cx) + }), + ], + }, + ); + } + + fn save_thread(&mut self, thread: Entity, cx: &mut Context) { + let id = thread.read(cx).id().clone(); + let Some(session) = self.sessions.get_mut(&id) else { + return; + }; + + let thread = thread.downgrade(); + let thread_database = self.thread_database.clone(); + session.save_task = cx.spawn(async move |this, cx| { + cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; + let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await; + thread_database.save_thread(id, db_thread).await?; + this.update(cx, |this, cx| this.reload_history(cx))?; + Ok(()) + }); + } + + fn reload_history(&mut self, cx: &mut Context) { + let thread_database = self.thread_database.clone(); + self.load_history = cx.spawn(async move |this, cx| { + let results = cx + .background_spawn(async move { + let results = thread_database.list_threads().await?; + Ok(results + .into_iter() + .map(|thread| AcpThreadMetadata { + agent: NATIVE_AGENT_SERVER_NAME.clone(), + id: thread.id.into(), + title: thread.title, + updated_at: thread.updated_at, + }) + .collect()) + }) + .await?; + this.update(cx, |this, cx| this.history.send(Some(results)))?; + anyhow::Ok(()) + }); + } + pub fn models(&self) -> &LanguageModels { &self.models } @@ -699,7 +775,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::debug!("Starting thread creation in async context"); // Generate session ID - let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into()); + let session_id = generate_session_id(); log::info!("Created session with ID: {}", session_id); // Create AcpThread @@ -743,6 +819,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { let thread = cx.new(|cx| { let mut thread = Thread::new( + session_id.clone(), project.clone(), agent.project_context.clone(), agent.context_server_registry.clone(), @@ -761,16 +838,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Store the session agent.update(cx, |agent, cx| { - agent.sessions.insert( - session_id, - Session { - thread, - acp_thread: acp_thread.downgrade(), - _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { - this.sessions.remove(acp_thread.session_id()); - }), - }, - ); + agent.insert_session(thread, acp_thread.clone(), cx) })?; Ok(acp_thread) @@ -785,35 +853,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection { Task::ready(Ok(())) } - fn list_threads(&self, cx: &mut App) -> Option>> { - let (mut tx, rx) = futures::channel::mpsc::unbounded(); - let database = self.0.update(cx, |this, _| { - this.history_listeners.push(tx.clone()); - this.thread_database.clone() - }); - cx.background_executor() - .spawn(async move { - dbg!("listing!"); - let database = database.await.map_err(|e| anyhow!(e))?; - let results = database.list_threads().await?; - - dbg!(&results); - tx.send( - results - .into_iter() - .map(|thread| AcpThreadMetadata { - agent: NATIVE_AGENT_SERVER_NAME.clone(), - id: thread.id.into(), - title: thread.title, - updated_at: thread.updated_at, - }) - .collect(), - ) - .await?; - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - Some(rx) + fn list_threads( + &self, + cx: &mut App, + ) -> Option>>> { + Some(self.0.read(cx).history.receiver()) } fn load_thread( @@ -890,16 +934,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection { // Store the session agent.update(cx, |agent, cx| { - agent.sessions.insert( - session_id, - Session { - thread: thread.clone(), - acp_thread: acp_thread.downgrade(), - _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| { - this.sessions.remove(acp_thread.session_id()); - }), - }, - ); + agent.insert_session(session_id, thread, acp_thread, cx) })?; let events = thread.update(cx, |thread, cx| thread.replay(cx))?; diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs index 1813fe188083cbc7c7251ba0d0f7f60efea5cd51..eee9810cefd590afff463bd1634eb61d593ba2b0 100644 --- a/crates/agent2/src/agent2.rs +++ b/crates/agent2/src/agent2.rs @@ -15,3 +15,9 @@ pub use native_agent_server::NativeAgentServer; pub use templates::*; pub use thread::*; pub use tools::*; + +use agent_client_protocol as acp; + +pub fn generate_session_id() -> acp::SessionId { + acp::SessionId(uuid::Uuid::new_v4().to_string().into()) +} diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 5cbddafded88e1373200317dd9342686232c7298..75a21a2baa72b7985709bfdb24b980a38aa0825a 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -709,9 +709,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { ); } -async fn expect_tool_call( - events: &mut UnboundedReceiver>, -) -> acp::ToolCall { +async fn expect_tool_call(events: &mut UnboundedReceiver>) -> acp::ToolCall { let event = events .next() .await @@ -1501,6 +1499,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { let action_log = cx.new(|_| ActionLog::new(project.clone())); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, project_context.clone(), context_server_registry, diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 034b26b714e6ec62bff371fe00bd9d176bca1ec7..ec820c7b5f02dee3003e90cf857535283f3ff0fb 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -1,25 +1,35 @@ -use crate::{ContextServerRegistry, DbThread, SystemPromptTemplate, Template, Templates}; +use crate::{ + ContextServerRegistry, DbLanguageModel, DbThread, SystemPromptTemplate, Template, Templates, +}; use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; +use agent::thread::{DetailedSummaryState, GitState, ProjectSnapshot, WorktreeSnapshot}; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::adapt_schema_to_format; +use chrono::{DateTime, Utc}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus}; use collections::IndexMap; use fs::Fs; use futures::{ + FutureExt, channel::{mpsc, oneshot}, + future::Shared, stream::FuturesUnordered, }; +use git::repository::DiffType; use gpui::{App, AppContext, Context, Entity, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, - LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, + LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason, TokenUsage, +}; +use project::{ + Project, + git_store::{GitStore, RepositoryState}, }; -use project::Project; use prompt_store::ProjectContext; use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; @@ -32,41 +42,6 @@ use uuid::Uuid; const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, -)] -pub struct ThreadId(pub(crate) Arc); - -impl ThreadId { - pub fn new() -> Self { - Self(Uuid::new_v4().to_string().into()) - } -} - -impl std::fmt::Display for ThreadId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From<&str> for ThreadId { - fn from(value: &str) -> Self { - Self(value.into()) - } -} - -impl From for ThreadId { - fn from(value: acp::SessionId) -> Self { - Self(value.0) - } -} - -impl From for acp::SessionId { - fn from(value: ThreadId) -> Self { - Self(value.0) - } -} - /// The ID of the user prompt that initiated a request. /// /// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). @@ -461,9 +436,28 @@ pub struct ToolCallAuthorization { pub response: oneshot::Sender, } +enum ThreadTitle { + None, + Pending(Task<()>), + Done(Result), +} + +impl ThreadTitle { + pub fn unwrap_or_default(&self) -> SharedString { + if let ThreadTitle::Done(Ok(title)) = self { + title.clone() + } else { + "New Thread".into() + } + } +} + pub struct Thread { - id: ThreadId, + id: acp::SessionId, prompt_id: PromptId, + updated_at: DateTime, + title: ThreadTitle, + summary: DetailedSummaryState, messages: Vec, completion_mode: CompletionMode, /// Holds the task that handles agent interaction until the end of the turn. @@ -473,6 +467,9 @@ pub struct Thread { pending_message: Option, tools: BTreeMap>, tool_use_limit_reached: bool, + request_token_usage: Vec, + cumulative_token_usage: TokenUsage, + initial_project_snapshot: Shared>>>, context_server_registry: Entity, profile_id: AgentProfileId, project_context: Rc>, @@ -484,6 +481,7 @@ pub struct Thread { impl Thread { pub fn new( + id: acp::SessionId, project: Entity, project_context: Rc>, context_server_registry: Entity, @@ -494,14 +492,25 @@ impl Thread { ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); Self { - id: ThreadId::new(), + id, prompt_id: PromptId::new(), + updated_at: Utc::now(), + title: ThreadTitle::None, + summary: DetailedSummaryState::default(), messages: Vec::new(), completion_mode: CompletionMode::Normal, running_turn: None, pending_message: None, tools: BTreeMap::default(), tool_use_limit_reached: false, + request_token_usage: Vec::new(), + cumulative_token_usage: TokenUsage::default(), + initial_project_snapshot: { + let project_snapshot = Self::project_snapshot(project.clone(), cx); + cx.foreground_executor() + .spawn(async move { Some(project_snapshot.await) }) + .shared() + }, context_server_registry, profile_id, project_context, @@ -512,8 +521,12 @@ impl Thread { } } + pub fn id(&self) -> &acp::SessionId { + &self.id + } + pub fn from_db( - id: ThreadId, + id: acp::SessionId, db_thread: DbThread, project: Entity, project_context: Rc>, @@ -529,12 +542,17 @@ impl Thread { Self { id, prompt_id: PromptId::new(), + title: ThreadTitle::Done(Ok(db_thread.title.clone())), + summary: db_thread.summary, messages: db_thread.messages, completion_mode: CompletionMode::Normal, running_turn: None, pending_message: None, tools: BTreeMap::default(), tool_use_limit_reached: false, + request_token_usage: db_thread.request_token_usage.clone(), + cumulative_token_usage: db_thread.cumulative_token_usage.clone(), + initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(), context_server_registry, profile_id, project_context, @@ -542,9 +560,35 @@ impl Thread { model, project, action_log, + updated_at: db_thread.updated_at, // todo!(figure out if we can remove the "recently opened" list) } } + pub fn to_db(&self, cx: &App) -> Task { + let initial_project_snapshot = self.initial_project_snapshot.clone(); + let mut thread = DbThread { + title: self.title.unwrap_or_default(), + messages: self.messages.clone(), + updated_at: self.updated_at.clone(), + summary: self.summary.clone(), + initial_project_snapshot: None, + cumulative_token_usage: self.cumulative_token_usage.clone(), + request_token_usage: self.request_token_usage.clone(), + model: Some(DbLanguageModel { + provider: self.model.provider_id().to_string(), + model: self.model.name().0.to_string(), + }), + completion_mode: Some(self.completion_mode.into()), + profile: Some(self.profile_id.clone()), + }; + + cx.background_spawn(async move { + let initial_project_snapshot = initial_project_snapshot.await; + thread.initial_project_snapshot = initial_project_snapshot; + thread + }) + } + pub fn replay( &mut self, cx: &mut Context, @@ -630,6 +674,122 @@ impl Thread { ); } + /// Create a snapshot of the current project state including git information and unsaved buffers. + fn project_snapshot( + project: Entity, + cx: &mut Context, + ) -> Task> { + let git_store = project.read(cx).git_store().clone(); + let worktree_snapshots: Vec<_> = project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) + .collect(); + + cx.spawn(async move |_, cx| { + let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; + + let mut unsaved_buffers = Vec::new(); + cx.update(|app_cx| { + let buffer_store = project.read(app_cx).buffer_store(); + for buffer_handle in buffer_store.read(app_cx).buffers() { + let buffer = buffer_handle.read(app_cx); + if buffer.is_dirty() { + if let Some(file) = buffer.file() { + let path = file.path().to_string_lossy().to_string(); + unsaved_buffers.push(path); + } + } + } + }) + .ok(); + + Arc::new(ProjectSnapshot { + worktree_snapshots, + unsaved_buffer_paths: unsaved_buffers, + timestamp: Utc::now(), + }) + }) + } + + fn worktree_snapshot( + worktree: Entity, + git_store: Entity, + cx: &App, + ) -> Task { + cx.spawn(async move |cx| { + // Get worktree path and snapshot + let worktree_info = cx.update(|app_cx| { + let worktree = worktree.read(app_cx); + let path = worktree.abs_path().to_string_lossy().to_string(); + let snapshot = worktree.snapshot(); + (path, snapshot) + }); + + let Ok((worktree_path, _snapshot)) = worktree_info else { + return WorktreeSnapshot { + worktree_path: String::new(), + git_state: None, + }; + }; + + let git_state = git_store + .update(cx, |git_store, cx| { + git_store + .repositories() + .values() + .find(|repo| { + repo.read(cx) + .abs_path_to_repo_path(&worktree.read(cx).abs_path()) + .is_some() + }) + .cloned() + }) + .ok() + .flatten() + .map(|repo| { + repo.update(cx, |repo, _| { + let current_branch = + repo.branch.as_ref().map(|branch| branch.name().to_owned()); + repo.send_job(None, |state, _| async move { + let RepositoryState::Local { backend, .. } = state else { + return GitState { + remote_url: None, + head_sha: None, + current_branch, + diff: None, + }; + }; + + let remote_url = backend.remote_url("origin"); + let head_sha = backend.head_sha().await; + let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); + + GitState { + remote_url, + head_sha, + current_branch, + diff, + } + }) + }) + }); + + let git_state = match git_state { + Some(git_state) => match git_state.ok() { + Some(git_state) => git_state.await.ok(), + None => None, + }, + None => None, + }; + + WorktreeSnapshot { + worktree_path, + git_state, + } + }) + } + pub fn project(&self) -> &Entity { &self.project } diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index 01fa77e22ddc58b27d67e8fab6d1cf0bd64ae84e..c320e8ea722ac407940cb1f015c76e9b29b51faf 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -522,7 +522,7 @@ fn resolve_path( #[cfg(test)] mod tests { use super::*; - use crate::{ContextServerRegistry, Templates}; + use crate::{ContextServerRegistry, Templates, generate_session_id}; use action_log::ActionLog; use client::TelemetrySettings; use fs::Fs; @@ -547,6 +547,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -748,6 +749,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -890,6 +892,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -1019,6 +1022,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -1157,6 +1161,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project, Rc::default(), context_server_registry, @@ -1267,6 +1272,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project.clone(), Rc::default(), context_server_registry.clone(), @@ -1349,6 +1355,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project.clone(), Rc::default(), context_server_registry.clone(), @@ -1434,6 +1441,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project.clone(), Rc::default(), context_server_registry.clone(), @@ -1516,6 +1524,7 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let thread = cx.new(|cx| { Thread::new( + generate_session_id(), project.clone(), Rc::default(), context_server_registry,