From ff108c89d70ced314dcde4f6d0e5cba39a4c404f Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Thu, 16 Apr 2026 14:11:06 +0200 Subject: [PATCH] acp: Fix close session not found error (#54009) Follow up to #53999 Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Closes #ISSUE Release Notes: - N/A --- Cargo.lock | 1 + crates/agent/src/agent.rs | 2 +- crates/agent_servers/Cargo.toml | 3 + crates/agent_servers/src/acp.rs | 579 +++++++++++++++++++++++++------- 4 files changed, 460 insertions(+), 125 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 32024536dbcf6c4c004f151ba8f6e837fee767f2..87a8eed0c016b47c36a002e5d3744885c5377bd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -274,6 +274,7 @@ dependencies = [ "libc", "log", "nix 0.29.0", + "piper", "project", "release_channel", "remote", diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index eaa8de69fca2d11434059a945ce05deddf56cb20..fcb901347a12798aa8e2e40942f88b47beee011d 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -944,7 +944,7 @@ impl NativeAgent { if let Some(pending) = self.pending_sessions.get_mut(&id) { pending.ref_count += 1; let task = pending.task.clone(); - return cx.spawn(async move |_, _cx| task.await.map_err(|err| anyhow!(err))); + return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) }); } let task = self.load_thread(id.clone(), project.clone(), cx); diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 5fbf1e821cb4a41f09c433ec05fdde9fbbde1a9f..85b206248c7e4ccd039bc92e911891a8cf830727 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -68,4 +68,7 @@ indoc.workspace = true acp_thread = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } gpui_tokio.workspace = true +piper = "0.2" +project = { workspace = true, features = ["test-support"] } reqwest_client = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 54c24c91c89cde8faa4ab351aa8990b92b578050..dae7888e65a01b09699aff59a758d200c03087e3 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -9,6 +9,8 @@ use anyhow::anyhow; use collections::HashMap; use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _}; use futures::AsyncBufReadExt as _; +use futures::FutureExt as _; +use futures::future::Shared; use futures::io::BufReader; use project::agent_server_store::{AgentServerCommand, AgentServerStore}; use project::{AgentId, Project}; @@ -25,6 +27,8 @@ use util::ResultExt as _; use util::path_list::PathList; use util::process::Child; +use std::sync::Arc; + use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity}; @@ -45,19 +49,31 @@ pub struct AcpConnection { telemetry_id: SharedString, connection: Rc, sessions: Rc>>, + pending_sessions: Rc>>, auth_methods: Vec, agent_server_store: WeakEntity, agent_capabilities: acp::AgentCapabilities, default_mode: Option, default_model: Option, default_config_options: HashMap, - child: Child, + child: Option, session_list: Option>, _io_task: Task>, _wait_task: Task>, _stderr_task: Task>, } +struct PendingAcpSession { + task: Shared, Arc>>>, + ref_count: usize, +} + +struct SessionConfigResponse { + modes: Option, + models: Option, + config_options: Option>, +} + struct ConfigOptions { config_options: Rc>>, tx: Rc>>, @@ -81,6 +97,7 @@ pub struct AcpSession { models: Option>>, session_modes: Option>>, config_options: Option, + ref_count: usize, } pub struct AcpSessionList { @@ -393,6 +410,7 @@ impl AcpConnection { connection, telemetry_id, sessions, + pending_sessions: Rc::new(RefCell::new(HashMap::default())), agent_capabilities: response.agent_capabilities, default_mode, default_model, @@ -401,7 +419,7 @@ impl AcpConnection { _io_task: io_task, _wait_task: wait_task, _stderr_task: stderr_task, - child, + child: Some(child), }) } @@ -409,6 +427,143 @@ impl AcpConnection { &self.agent_capabilities.prompt_capabilities } + #[cfg(test)] + fn new_for_test( + connection: Rc, + sessions: Rc>>, + agent_capabilities: acp::AgentCapabilities, + agent_server_store: WeakEntity, + io_task: Task>, + _cx: &mut App, + ) -> Self { + Self { + id: AgentId::new("test"), + telemetry_id: "test".into(), + connection, + sessions, + pending_sessions: Rc::new(RefCell::new(HashMap::default())), + auth_methods: vec![], + agent_server_store, + agent_capabilities, + default_mode: None, + default_model: None, + default_config_options: HashMap::default(), + child: None, + session_list: None, + _io_task: io_task, + _wait_task: Task::ready(Ok(())), + _stderr_task: Task::ready(Ok(())), + } + } + + fn open_or_create_session( + self: Rc, + session_id: acp::SessionId, + project: Entity, + work_dirs: PathList, + title: Option, + rpc_call: impl FnOnce( + Rc, + acp::SessionId, + PathBuf, + ) + -> futures::future::LocalBoxFuture<'static, Result> + + 'static, + cx: &mut App, + ) -> Task>> { + if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) { + session.ref_count += 1; + if let Some(thread) = session.thread.upgrade() { + return Task::ready(Ok(thread)); + } + } + + if let Some(pending) = self.pending_sessions.borrow_mut().get_mut(&session_id) { + pending.ref_count += 1; + let task = pending.task.clone(); + return cx + .foreground_executor() + .spawn(async move { task.await.map_err(|err| anyhow!(err)) }); + } + + // TODO: remove this once ACP supports multiple working directories + let Some(cwd) = work_dirs.ordered_paths().next().cloned() else { + return Task::ready(Err(anyhow!("Working directory cannot be empty"))); + }; + + let shared_task = cx + .spawn({ + let session_id = session_id.clone(); + let this = self.clone(); + async move |cx| { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let thread: Entity = cx.new(|cx| { + AcpThread::new( + None, + title, + Some(work_dirs), + this.clone(), + project, + action_log, + session_id.clone(), + watch::Receiver::constant( + this.agent_capabilities.prompt_capabilities.clone(), + ), + cx, + ) + }); + + let response = + match rpc_call(this.connection.clone(), session_id.clone(), cwd).await { + Ok(response) => response, + Err(err) => { + this.pending_sessions.borrow_mut().remove(&session_id); + return Err(Arc::new(err)); + } + }; + + let (modes, models, config_options) = + config_state(response.modes, response.models, response.config_options); + + if let Some(config_opts) = config_options.as_ref() { + this.apply_default_config_options(&session_id, config_opts, cx); + } + + let ref_count = this + .pending_sessions + .borrow_mut() + .remove(&session_id) + .map_or(1, |pending| pending.ref_count); + + this.sessions.borrow_mut().insert( + session_id, + AcpSession { + thread: thread.downgrade(), + suppress_abort_err: false, + session_modes: modes, + models, + config_options: config_options.map(ConfigOptions::new), + ref_count, + }, + ); + + Ok(thread) + } + }) + .shared(); + + self.pending_sessions.borrow_mut().insert( + session_id, + PendingAcpSession { + task: shared_task.clone(), + ref_count: 1, + }, + ); + + cx.foreground_executor() + .spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) }) + } + fn apply_default_config_options( &self, session_id: &acp::SessionId, @@ -508,7 +663,9 @@ impl AcpConnection { impl Drop for AcpConnection { fn drop(&mut self) { - self.child.kill().log_err(); + if let Some(ref mut child) = self.child { + child.kill().log_err(); + } } } @@ -700,6 +857,7 @@ impl AgentConnection for AcpConnection { session_modes: modes, models, config_options: config_options.map(ConfigOptions::new), + ref_count: 1, }, ); @@ -731,68 +889,30 @@ impl AgentConnection for AcpConnection { "Loading sessions is not supported by this agent.".into() )))); } - // TODO: remove this once ACP supports multiple working directories - let Some(cwd) = work_dirs.ordered_paths().next().cloned() else { - return Task::ready(Err(anyhow!("Working directory cannot be empty"))); - }; let mcp_servers = mcp_servers_for_project(&project, cx); - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let thread: Entity = cx.new(|cx| { - AcpThread::new( - None, - title, - Some(work_dirs.clone()), - self.clone(), - project, - action_log, - session_id.clone(), - watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()), - cx, - ) - }); - - self.sessions.borrow_mut().insert( - session_id.clone(), - AcpSession { - thread: thread.downgrade(), - suppress_abort_err: false, - session_modes: None, - models: None, - config_options: None, + self.open_or_create_session( + session_id, + project, + work_dirs, + title, + move |connection, session_id, cwd| { + Box::pin(async move { + let response = connection + .load_session( + acp::LoadSessionRequest::new(session_id, cwd).mcp_servers(mcp_servers), + ) + .await + .map_err(map_acp_error)?; + Ok(SessionConfigResponse { + modes: response.modes, + models: response.models, + config_options: response.config_options, + }) + }) }, - ); - - cx.spawn(async move |cx| { - let response = match self - .connection - .load_session( - acp::LoadSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers), - ) - .await - { - Ok(response) => response, - Err(err) => { - self.sessions.borrow_mut().remove(&session_id); - return Err(map_acp_error(err)); - } - }; - - let (modes, models, config_options) = - config_state(response.modes, response.models, response.config_options); - - if let Some(config_opts) = config_options.as_ref() { - self.apply_default_config_options(&session_id, config_opts, cx); - } - - if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) { - session.session_modes = modes; - session.models = models; - session.config_options = config_options.map(ConfigOptions::new); - } - - Ok(thread) - }) + cx, + ) } fn resume_session( @@ -813,69 +933,31 @@ impl AgentConnection for AcpConnection { "Resuming sessions is not supported by this agent.".into() )))); } - // TODO: remove this once ACP supports multiple working directories - let Some(cwd) = work_dirs.ordered_paths().next().cloned() else { - return Task::ready(Err(anyhow!("Working directory cannot be empty"))); - }; let mcp_servers = mcp_servers_for_project(&project, cx); - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let thread: Entity = cx.new(|cx| { - AcpThread::new( - None, - title, - Some(work_dirs), - self.clone(), - project, - action_log, - session_id.clone(), - watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()), - cx, - ) - }); - - self.sessions.borrow_mut().insert( - session_id.clone(), - AcpSession { - thread: thread.downgrade(), - suppress_abort_err: false, - session_modes: None, - models: None, - config_options: None, + self.open_or_create_session( + session_id, + project, + work_dirs, + title, + move |connection, session_id, cwd| { + Box::pin(async move { + let response = connection + .resume_session( + acp::ResumeSessionRequest::new(session_id, cwd) + .mcp_servers(mcp_servers), + ) + .await + .map_err(map_acp_error)?; + Ok(SessionConfigResponse { + modes: response.modes, + models: response.models, + config_options: response.config_options, + }) + }) }, - ); - - cx.spawn(async move |cx| { - let response = match self - .connection - .resume_session( - acp::ResumeSessionRequest::new(session_id.clone(), cwd) - .mcp_servers(mcp_servers), - ) - .await - { - Ok(response) => response, - Err(err) => { - self.sessions.borrow_mut().remove(&session_id); - return Err(map_acp_error(err)); - } - }; - - let (modes, models, config_options) = - config_state(response.modes, response.models, response.config_options); - - if let Some(config_opts) = config_options.as_ref() { - self.apply_default_config_options(&session_id, config_opts, cx); - } - - if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) { - session.session_modes = modes; - session.models = models; - session.config_options = config_options.map(ConfigOptions::new); - } - - Ok(thread) - }) + cx, + ) } fn supports_close_session(&self) -> bool { @@ -893,12 +975,24 @@ impl AgentConnection for AcpConnection { )))); } + let mut sessions = self.sessions.borrow_mut(); + let Some(session) = sessions.get_mut(session_id) else { + return Task::ready(Ok(())); + }; + + session.ref_count -= 1; + if session.ref_count > 0 { + return Task::ready(Ok(())); + } + + sessions.remove(session_id); + drop(sessions); + let conn = self.connection.clone(); let session_id = session_id.clone(); cx.foreground_executor().spawn(async move { - conn.close_session(acp::CloseSessionRequest::new(session_id.clone())) + conn.close_session(acp::CloseSessionRequest::new(session_id)) .await?; - self.sessions.borrow_mut().remove(&session_id); Ok(()) }) } @@ -1112,6 +1206,8 @@ fn map_acp_error(err: acp::Error) -> anyhow::Error { #[cfg(test)] mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + use super::*; #[test] @@ -1240,6 +1336,241 @@ mod tests { ); assert_eq!(task.label, "Login"); } + + struct FakeAcpAgent { + load_session_count: Arc, + close_session_count: Arc, + } + + #[async_trait::async_trait(?Send)] + impl acp::Agent for FakeAcpAgent { + async fn initialize( + &self, + args: acp::InitializeRequest, + ) -> acp::Result { + Ok( + acp::InitializeResponse::new(args.protocol_version).agent_capabilities( + acp::AgentCapabilities::default() + .load_session(true) + .session_capabilities( + acp::SessionCapabilities::default() + .close(acp::SessionCloseCapabilities::new()), + ), + ), + ) + } + + async fn authenticate( + &self, + _: acp::AuthenticateRequest, + ) -> acp::Result { + Ok(Default::default()) + } + + async fn new_session( + &self, + _: acp::NewSessionRequest, + ) -> acp::Result { + Ok(acp::NewSessionResponse::new(acp::SessionId::new("unused"))) + } + + async fn prompt(&self, _: acp::PromptRequest) -> acp::Result { + Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) + } + + async fn cancel(&self, _: acp::CancelNotification) -> acp::Result<()> { + Ok(()) + } + + async fn load_session( + &self, + _: acp::LoadSessionRequest, + ) -> acp::Result { + self.load_session_count.fetch_add(1, Ordering::SeqCst); + Ok(acp::LoadSessionResponse::new()) + } + + async fn close_session( + &self, + _: acp::CloseSessionRequest, + ) -> acp::Result { + self.close_session_count.fetch_add(1, Ordering::SeqCst); + Ok(acp::CloseSessionResponse::new()) + } + } + + async fn connect_fake_agent( + cx: &mut gpui::TestAppContext, + ) -> ( + Rc, + Entity, + Arc, + Arc, + Task>, + ) { + cx.update(|cx| { + let store = settings::SettingsStore::test(cx); + cx.set_global(store); + }); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/", serde_json::json!({ "a": {} })).await; + let project = project::Project::test(fs, [std::path::Path::new("/a")], cx).await; + + let load_count = Arc::new(AtomicUsize::new(0)); + let close_count = Arc::new(AtomicUsize::new(0)); + + let (c2a_reader, c2a_writer) = piper::pipe(4096); + let (a2c_reader, a2c_writer) = piper::pipe(4096); + + let sessions: Rc>> = + Rc::new(RefCell::new(HashMap::default())); + let session_list_container: Rc>>> = + Rc::new(RefCell::new(None)); + + let foreground = cx.foreground_executor().clone(); + + let client_delegate = ClientDelegate { + sessions: sessions.clone(), + session_list: session_list_container, + cx: cx.to_async(), + }; + + let (client_conn, client_io_task) = + acp::ClientSideConnection::new(client_delegate, c2a_writer, a2c_reader, { + let foreground = foreground.clone(); + move |fut| { + foreground.spawn(fut).detach(); + } + }); + + let fake_agent = FakeAcpAgent { + load_session_count: load_count.clone(), + close_session_count: close_count.clone(), + }; + + let (_, agent_io_task) = + acp::AgentSideConnection::new(fake_agent, a2c_writer, c2a_reader, { + let foreground = foreground.clone(); + move |fut| { + foreground.spawn(fut).detach(); + } + }); + + let client_io_task = cx.background_spawn(client_io_task); + let agent_io_task = cx.background_spawn(agent_io_task); + + let response = client_conn + .initialize(acp::InitializeRequest::new(acp::ProtocolVersion::V1)) + .await + .expect("failed to initialize ACP connection"); + + let agent_capabilities = response.agent_capabilities; + + let agent_server_store = + project.read_with(cx, |project, _| project.agent_server_store().downgrade()); + + let connection = cx.update(|cx| { + AcpConnection::new_for_test( + Rc::new(client_conn), + sessions, + agent_capabilities, + agent_server_store, + client_io_task, + cx, + ) + }); + + let keep_agent_alive = cx.background_spawn(async move { + agent_io_task.await.ok(); + anyhow::Ok(()) + }); + + ( + Rc::new(connection), + project, + load_count, + close_count, + keep_agent_alive, + ) + } + + #[gpui::test] + async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut gpui::TestAppContext) { + let (connection, project, load_count, close_count, _keep_agent_alive) = + connect_fake_agent(cx).await; + + let session_id = acp::SessionId::new("session-1"); + let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]); + + // Load the same session twice concurrently — the second call should join + // the pending task rather than issuing a second ACP load_session RPC. + let first_load = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + work_dirs.clone(), + None, + cx, + ) + }); + let second_load = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + work_dirs.clone(), + None, + cx, + ) + }); + + let first_thread = first_load.await.expect("first load failed"); + let second_thread = second_load.await.expect("second load failed"); + cx.run_until_parked(); + + assert_eq!( + first_thread.entity_id(), + second_thread.entity_id(), + "concurrent loads for the same session should share one AcpThread" + ); + assert_eq!( + load_count.load(Ordering::SeqCst), + 1, + "underlying ACP load_session should be called exactly once for concurrent loads" + ); + + // The session has ref_count 2. The first close should not send the ACP + // close_session RPC — the session is still referenced. + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .expect("first close failed"); + + assert_eq!( + close_count.load(Ordering::SeqCst), + 0, + "ACP close_session should not be sent while ref_count > 0" + ); + assert!( + connection.sessions.borrow().contains_key(&session_id), + "session should still be tracked after first close" + ); + + // The second close drops ref_count to 0 — now the ACP RPC must be sent. + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .expect("second close failed"); + cx.run_until_parked(); + + assert_eq!( + close_count.load(Ordering::SeqCst), + 1, + "ACP close_session should be sent exactly once when ref_count reaches 0" + ); + assert!( + !connection.sessions.borrow().contains_key(&session_id), + "session should be removed after final close" + ); + } } fn mcp_servers_for_project(project: &Entity, cx: &App) -> Vec {