acp: Fix close session not found error (#54009)

Bennet Bo Fenner created

Follow-up to #53999

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Closes #ISSUE

Release Notes:

- N/A

Change summary

Cargo.lock                      |   1 
crates/agent/src/agent.rs       |   2 
crates/agent_servers/Cargo.toml |   3 
crates/agent_servers/src/acp.rs | 579 +++++++++++++++++++++++++++-------
4 files changed, 460 insertions(+), 125 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -274,6 +274,7 @@ dependencies = [
  "libc",
  "log",
  "nix 0.29.0",
+ "piper",
  "project",
  "release_channel",
  "remote",

crates/agent/src/agent.rs 🔗

@@ -944,7 +944,7 @@ impl NativeAgent {
         if let Some(pending) = self.pending_sessions.get_mut(&id) {
             pending.ref_count += 1;
             let task = pending.task.clone();
-            return cx.spawn(async move |_, _cx| task.await.map_err(|err| anyhow!(err)));
+            return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
         }
 
         let task = self.load_thread(id.clone(), project.clone(), cx);

crates/agent_servers/Cargo.toml 🔗

@@ -68,4 +68,7 @@ indoc.workspace = true
 acp_thread = { workspace = true, features = ["test-support"] }
 gpui = { workspace = true, features = ["test-support"] }
 gpui_tokio.workspace = true
+piper = "0.2"
+project = { workspace = true, features = ["test-support"] }
 reqwest_client = { workspace = true, features = ["test-support"] }
+settings = { workspace = true, features = ["test-support"] }

crates/agent_servers/src/acp.rs 🔗

@@ -9,6 +9,8 @@ use anyhow::anyhow;
 use collections::HashMap;
 use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _};
 use futures::AsyncBufReadExt as _;
+use futures::FutureExt as _;
+use futures::future::Shared;
 use futures::io::BufReader;
 use project::agent_server_store::{AgentServerCommand, AgentServerStore};
 use project::{AgentId, Project};
@@ -25,6 +27,8 @@ use util::ResultExt as _;
 use util::path_list::PathList;
 use util::process::Child;
 
+use std::sync::Arc;
+
 use anyhow::{Context as _, Result};
 use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 
@@ -45,19 +49,31 @@ pub struct AcpConnection {
     telemetry_id: SharedString,
     connection: Rc<acp::ClientSideConnection>,
     sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
+    pending_sessions: Rc<RefCell<HashMap<acp::SessionId, PendingAcpSession>>>,
     auth_methods: Vec<acp::AuthMethod>,
     agent_server_store: WeakEntity<AgentServerStore>,
     agent_capabilities: acp::AgentCapabilities,
     default_mode: Option<acp::SessionModeId>,
     default_model: Option<acp::ModelId>,
     default_config_options: HashMap<String, String>,
-    child: Child,
+    child: Option<Child>,
     session_list: Option<Rc<AcpSessionList>>,
     _io_task: Task<Result<(), acp::Error>>,
     _wait_task: Task<Result<()>>,
     _stderr_task: Task<Result<()>>,
 }
 
+struct PendingAcpSession {
+    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
+    ref_count: usize,
+}
+
+struct SessionConfigResponse {
+    modes: Option<acp::SessionModeState>,
+    models: Option<acp::SessionModelState>,
+    config_options: Option<Vec<acp::SessionConfigOption>>,
+}
+
 struct ConfigOptions {
     config_options: Rc<RefCell<Vec<acp::SessionConfigOption>>>,
     tx: Rc<RefCell<watch::Sender<()>>>,
@@ -81,6 +97,7 @@ pub struct AcpSession {
     models: Option<Rc<RefCell<acp::SessionModelState>>>,
     session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
     config_options: Option<ConfigOptions>,
+    ref_count: usize,
 }
 
 pub struct AcpSessionList {
@@ -393,6 +410,7 @@ impl AcpConnection {
             connection,
             telemetry_id,
             sessions,
+            pending_sessions: Rc::new(RefCell::new(HashMap::default())),
             agent_capabilities: response.agent_capabilities,
             default_mode,
             default_model,
@@ -401,7 +419,7 @@ impl AcpConnection {
             _io_task: io_task,
             _wait_task: wait_task,
             _stderr_task: stderr_task,
-            child,
+            child: Some(child),
         })
     }
 
@@ -409,6 +427,143 @@ impl AcpConnection {
         &self.agent_capabilities.prompt_capabilities
     }
 
+    #[cfg(test)]
+    fn new_for_test(
+        connection: Rc<acp::ClientSideConnection>,
+        sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
+        agent_capabilities: acp::AgentCapabilities,
+        agent_server_store: WeakEntity<AgentServerStore>,
+        io_task: Task<Result<(), acp::Error>>,
+        _cx: &mut App,
+    ) -> Self {
+        Self {
+            id: AgentId::new("test"),
+            telemetry_id: "test".into(),
+            connection,
+            sessions,
+            pending_sessions: Rc::new(RefCell::new(HashMap::default())),
+            auth_methods: vec![],
+            agent_server_store,
+            agent_capabilities,
+            default_mode: None,
+            default_model: None,
+            default_config_options: HashMap::default(),
+            child: None,
+            session_list: None,
+            _io_task: io_task,
+            _wait_task: Task::ready(Ok(())),
+            _stderr_task: Task::ready(Ok(())),
+        }
+    }
+
+    fn open_or_create_session(
+        self: Rc<Self>,
+        session_id: acp::SessionId,
+        project: Entity<Project>,
+        work_dirs: PathList,
+        title: Option<SharedString>,
+        rpc_call: impl FnOnce(
+            Rc<acp::ClientSideConnection>,
+            acp::SessionId,
+            PathBuf,
+        )
+            -> futures::future::LocalBoxFuture<'static, Result<SessionConfigResponse>>
+        + 'static,
+        cx: &mut App,
+    ) -> Task<Result<Entity<AcpThread>>> {
+        if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
+            session.ref_count += 1;
+            if let Some(thread) = session.thread.upgrade() {
+                return Task::ready(Ok(thread));
+            }
+        }
+
+        if let Some(pending) = self.pending_sessions.borrow_mut().get_mut(&session_id) {
+            pending.ref_count += 1;
+            let task = pending.task.clone();
+            return cx
+                .foreground_executor()
+                .spawn(async move { task.await.map_err(|err| anyhow!(err)) });
+        }
+
+        // TODO: remove this once ACP supports multiple working directories
+        let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
+            return Task::ready(Err(anyhow!("Working directory cannot be empty")));
+        };
+
+        let shared_task = cx
+            .spawn({
+                let session_id = session_id.clone();
+                let this = self.clone();
+                async move |cx| {
+                    let action_log = cx.new(|_| ActionLog::new(project.clone()));
+                    let thread: Entity<AcpThread> = cx.new(|cx| {
+                        AcpThread::new(
+                            None,
+                            title,
+                            Some(work_dirs),
+                            this.clone(),
+                            project,
+                            action_log,
+                            session_id.clone(),
+                            watch::Receiver::constant(
+                                this.agent_capabilities.prompt_capabilities.clone(),
+                            ),
+                            cx,
+                        )
+                    });
+
+                    let response =
+                        match rpc_call(this.connection.clone(), session_id.clone(), cwd).await {
+                            Ok(response) => response,
+                            Err(err) => {
+                                this.pending_sessions.borrow_mut().remove(&session_id);
+                                return Err(Arc::new(err));
+                            }
+                        };
+
+                    let (modes, models, config_options) =
+                        config_state(response.modes, response.models, response.config_options);
+
+                    if let Some(config_opts) = config_options.as_ref() {
+                        this.apply_default_config_options(&session_id, config_opts, cx);
+                    }
+
+                    let ref_count = this
+                        .pending_sessions
+                        .borrow_mut()
+                        .remove(&session_id)
+                        .map_or(1, |pending| pending.ref_count);
+
+                    this.sessions.borrow_mut().insert(
+                        session_id,
+                        AcpSession {
+                            thread: thread.downgrade(),
+                            suppress_abort_err: false,
+                            session_modes: modes,
+                            models,
+                            config_options: config_options.map(ConfigOptions::new),
+                            ref_count,
+                        },
+                    );
+
+                    Ok(thread)
+                }
+            })
+            .shared();
+
+        self.pending_sessions.borrow_mut().insert(
+            session_id,
+            PendingAcpSession {
+                task: shared_task.clone(),
+                ref_count: 1,
+            },
+        );
+
+        cx.foreground_executor()
+            .spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
+    }
+
     fn apply_default_config_options(
         &self,
         session_id: &acp::SessionId,
@@ -508,7 +663,9 @@ impl AcpConnection {
 
 impl Drop for AcpConnection {
     fn drop(&mut self) {
-        self.child.kill().log_err();
+        if let Some(ref mut child) = self.child {
+            child.kill().log_err();
+        }
     }
 }
 
@@ -700,6 +857,7 @@ impl AgentConnection for AcpConnection {
                     session_modes: modes,
                     models,
                     config_options: config_options.map(ConfigOptions::new),
+                    ref_count: 1,
                 },
             );
 
@@ -731,68 +889,30 @@ impl AgentConnection for AcpConnection {
                 "Loading sessions is not supported by this agent.".into()
             ))));
         }
-        // TODO: remove this once ACP supports multiple working directories
-        let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
-            return Task::ready(Err(anyhow!("Working directory cannot be empty")));
-        };
 
         let mcp_servers = mcp_servers_for_project(&project, cx);
-        let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let thread: Entity<AcpThread> = cx.new(|cx| {
-            AcpThread::new(
-                None,
-                title,
-                Some(work_dirs.clone()),
-                self.clone(),
-                project,
-                action_log,
-                session_id.clone(),
-                watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
-                cx,
-            )
-        });
-
-        self.sessions.borrow_mut().insert(
-            session_id.clone(),
-            AcpSession {
-                thread: thread.downgrade(),
-                suppress_abort_err: false,
-                session_modes: None,
-                models: None,
-                config_options: None,
+        self.open_or_create_session(
+            session_id,
+            project,
+            work_dirs,
+            title,
+            move |connection, session_id, cwd| {
+                Box::pin(async move {
+                    let response = connection
+                        .load_session(
+                            acp::LoadSessionRequest::new(session_id, cwd).mcp_servers(mcp_servers),
+                        )
+                        .await
+                        .map_err(map_acp_error)?;
+                    Ok(SessionConfigResponse {
+                        modes: response.modes,
+                        models: response.models,
+                        config_options: response.config_options,
+                    })
+                })
             },
-        );
-
-        cx.spawn(async move |cx| {
-            let response = match self
-                .connection
-                .load_session(
-                    acp::LoadSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers),
-                )
-                .await
-            {
-                Ok(response) => response,
-                Err(err) => {
-                    self.sessions.borrow_mut().remove(&session_id);
-                    return Err(map_acp_error(err));
-                }
-            };
-
-            let (modes, models, config_options) =
-                config_state(response.modes, response.models, response.config_options);
-
-            if let Some(config_opts) = config_options.as_ref() {
-                self.apply_default_config_options(&session_id, config_opts, cx);
-            }
-
-            if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
-                session.session_modes = modes;
-                session.models = models;
-                session.config_options = config_options.map(ConfigOptions::new);
-            }
-
-            Ok(thread)
-        })
+            cx,
+        )
     }
 
     fn resume_session(
@@ -813,69 +933,31 @@ impl AgentConnection for AcpConnection {
                 "Resuming sessions is not supported by this agent.".into()
             ))));
         }
-        // TODO: remove this once ACP supports multiple working directories
-        let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
-            return Task::ready(Err(anyhow!("Working directory cannot be empty")));
-        };
 
         let mcp_servers = mcp_servers_for_project(&project, cx);
-        let action_log = cx.new(|_| ActionLog::new(project.clone()));
-        let thread: Entity<AcpThread> = cx.new(|cx| {
-            AcpThread::new(
-                None,
-                title,
-                Some(work_dirs),
-                self.clone(),
-                project,
-                action_log,
-                session_id.clone(),
-                watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
-                cx,
-            )
-        });
-
-        self.sessions.borrow_mut().insert(
-            session_id.clone(),
-            AcpSession {
-                thread: thread.downgrade(),
-                suppress_abort_err: false,
-                session_modes: None,
-                models: None,
-                config_options: None,
+        self.open_or_create_session(
+            session_id,
+            project,
+            work_dirs,
+            title,
+            move |connection, session_id, cwd| {
+                Box::pin(async move {
+                    let response = connection
+                        .resume_session(
+                            acp::ResumeSessionRequest::new(session_id, cwd)
+                                .mcp_servers(mcp_servers),
+                        )
+                        .await
+                        .map_err(map_acp_error)?;
+                    Ok(SessionConfigResponse {
+                        modes: response.modes,
+                        models: response.models,
+                        config_options: response.config_options,
+                    })
+                })
             },
-        );
-
-        cx.spawn(async move |cx| {
-            let response = match self
-                .connection
-                .resume_session(
-                    acp::ResumeSessionRequest::new(session_id.clone(), cwd)
-                        .mcp_servers(mcp_servers),
-                )
-                .await
-            {
-                Ok(response) => response,
-                Err(err) => {
-                    self.sessions.borrow_mut().remove(&session_id);
-                    return Err(map_acp_error(err));
-                }
-            };
-
-            let (modes, models, config_options) =
-                config_state(response.modes, response.models, response.config_options);
-
-            if let Some(config_opts) = config_options.as_ref() {
-                self.apply_default_config_options(&session_id, config_opts, cx);
-            }
-
-            if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
-                session.session_modes = modes;
-                session.models = models;
-                session.config_options = config_options.map(ConfigOptions::new);
-            }
-
-            Ok(thread)
-        })
+            cx,
+        )
     }
 
     fn supports_close_session(&self) -> bool {
@@ -893,12 +975,24 @@ impl AgentConnection for AcpConnection {
             ))));
         }
 
+        let mut sessions = self.sessions.borrow_mut();
+        let Some(session) = sessions.get_mut(session_id) else {
+            return Task::ready(Ok(()));
+        };
+
+        session.ref_count -= 1;
+        if session.ref_count > 0 {
+            return Task::ready(Ok(()));
+        }
+
+        sessions.remove(session_id);
+        drop(sessions);
+
         let conn = self.connection.clone();
         let session_id = session_id.clone();
         cx.foreground_executor().spawn(async move {
-            conn.close_session(acp::CloseSessionRequest::new(session_id.clone()))
+            conn.close_session(acp::CloseSessionRequest::new(session_id))
                 .await?;
-            self.sessions.borrow_mut().remove(&session_id);
             Ok(())
         })
     }
@@ -1112,6 +1206,8 @@ fn map_acp_error(err: acp::Error) -> anyhow::Error {
 
 #[cfg(test)]
 mod tests {
+    use std::sync::atomic::{AtomicUsize, Ordering};
+
     use super::*;
 
     #[test]
@@ -1240,6 +1336,241 @@ mod tests {
         );
         assert_eq!(task.label, "Login");
     }
+
+    struct FakeAcpAgent {
+        load_session_count: Arc<AtomicUsize>,
+        close_session_count: Arc<AtomicUsize>,
+    }
+
+    #[async_trait::async_trait(?Send)]
+    impl acp::Agent for FakeAcpAgent {
+        async fn initialize(
+            &self,
+            args: acp::InitializeRequest,
+        ) -> acp::Result<acp::InitializeResponse> {
+            Ok(
+                acp::InitializeResponse::new(args.protocol_version).agent_capabilities(
+                    acp::AgentCapabilities::default()
+                        .load_session(true)
+                        .session_capabilities(
+                            acp::SessionCapabilities::default()
+                                .close(acp::SessionCloseCapabilities::new()),
+                        ),
+                ),
+            )
+        }
+
+        async fn authenticate(
+            &self,
+            _: acp::AuthenticateRequest,
+        ) -> acp::Result<acp::AuthenticateResponse> {
+            Ok(Default::default())
+        }
+
+        async fn new_session(
+            &self,
+            _: acp::NewSessionRequest,
+        ) -> acp::Result<acp::NewSessionResponse> {
+            Ok(acp::NewSessionResponse::new(acp::SessionId::new("unused")))
+        }
+
+        async fn prompt(&self, _: acp::PromptRequest) -> acp::Result<acp::PromptResponse> {
+            Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
+        }
+
+        async fn cancel(&self, _: acp::CancelNotification) -> acp::Result<()> {
+            Ok(())
+        }
+
+        async fn load_session(
+            &self,
+            _: acp::LoadSessionRequest,
+        ) -> acp::Result<acp::LoadSessionResponse> {
+            self.load_session_count.fetch_add(1, Ordering::SeqCst);
+            Ok(acp::LoadSessionResponse::new())
+        }
+
+        async fn close_session(
+            &self,
+            _: acp::CloseSessionRequest,
+        ) -> acp::Result<acp::CloseSessionResponse> {
+            self.close_session_count.fetch_add(1, Ordering::SeqCst);
+            Ok(acp::CloseSessionResponse::new())
+        }
+    }
+
+    async fn connect_fake_agent(
+        cx: &mut gpui::TestAppContext,
+    ) -> (
+        Rc<AcpConnection>,
+        Entity<project::Project>,
+        Arc<AtomicUsize>,
+        Arc<AtomicUsize>,
+        Task<anyhow::Result<()>>,
+    ) {
+        cx.update(|cx| {
+            let store = settings::SettingsStore::test(cx);
+            cx.set_global(store);
+        });
+
+        let fs = fs::FakeFs::new(cx.executor());
+        fs.insert_tree("/", serde_json::json!({ "a": {} })).await;
+        let project = project::Project::test(fs, [std::path::Path::new("/a")], cx).await;
+
+        let load_count = Arc::new(AtomicUsize::new(0));
+        let close_count = Arc::new(AtomicUsize::new(0));
+
+        let (c2a_reader, c2a_writer) = piper::pipe(4096);
+        let (a2c_reader, a2c_writer) = piper::pipe(4096);
+
+        let sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>> =
+            Rc::new(RefCell::new(HashMap::default()));
+        let session_list_container: Rc<RefCell<Option<Rc<AcpSessionList>>>> =
+            Rc::new(RefCell::new(None));
+
+        let foreground = cx.foreground_executor().clone();
+
+        let client_delegate = ClientDelegate {
+            sessions: sessions.clone(),
+            session_list: session_list_container,
+            cx: cx.to_async(),
+        };
+
+        let (client_conn, client_io_task) =
+            acp::ClientSideConnection::new(client_delegate, c2a_writer, a2c_reader, {
+                let foreground = foreground.clone();
+                move |fut| {
+                    foreground.spawn(fut).detach();
+                }
+            });
+
+        let fake_agent = FakeAcpAgent {
+            load_session_count: load_count.clone(),
+            close_session_count: close_count.clone(),
+        };
+
+        let (_, agent_io_task) =
+            acp::AgentSideConnection::new(fake_agent, a2c_writer, c2a_reader, {
+                let foreground = foreground.clone();
+                move |fut| {
+                    foreground.spawn(fut).detach();
+                }
+            });
+
+        let client_io_task = cx.background_spawn(client_io_task);
+        let agent_io_task = cx.background_spawn(agent_io_task);
+
+        let response = client_conn
+            .initialize(acp::InitializeRequest::new(acp::ProtocolVersion::V1))
+            .await
+            .expect("failed to initialize ACP connection");
+
+        let agent_capabilities = response.agent_capabilities;
+
+        let agent_server_store =
+            project.read_with(cx, |project, _| project.agent_server_store().downgrade());
+
+        let connection = cx.update(|cx| {
+            AcpConnection::new_for_test(
+                Rc::new(client_conn),
+                sessions,
+                agent_capabilities,
+                agent_server_store,
+                client_io_task,
+                cx,
+            )
+        });
+
+        let keep_agent_alive = cx.background_spawn(async move {
+            agent_io_task.await.ok();
+            anyhow::Ok(())
+        });
+
+        (
+            Rc::new(connection),
+            project,
+            load_count,
+            close_count,
+            keep_agent_alive,
+        )
+    }
+
+    #[gpui::test]
+    async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut gpui::TestAppContext) {
+        let (connection, project, load_count, close_count, _keep_agent_alive) =
+            connect_fake_agent(cx).await;
+
+        let session_id = acp::SessionId::new("session-1");
+        let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
+
+        // Load the same session twice concurrently — the second call should join
+        // the pending task rather than issuing a second ACP load_session RPC.
+        let first_load = cx.update(|cx| {
+            connection.clone().load_session(
+                session_id.clone(),
+                project.clone(),
+                work_dirs.clone(),
+                None,
+                cx,
+            )
+        });
+        let second_load = cx.update(|cx| {
+            connection.clone().load_session(
+                session_id.clone(),
+                project.clone(),
+                work_dirs.clone(),
+                None,
+                cx,
+            )
+        });
+
+        let first_thread = first_load.await.expect("first load failed");
+        let second_thread = second_load.await.expect("second load failed");
+        cx.run_until_parked();
+
+        assert_eq!(
+            first_thread.entity_id(),
+            second_thread.entity_id(),
+            "concurrent loads for the same session should share one AcpThread"
+        );
+        assert_eq!(
+            load_count.load(Ordering::SeqCst),
+            1,
+            "underlying ACP load_session should be called exactly once for concurrent loads"
+        );
+
+        // The session has ref_count 2. The first close should not send the ACP
+        // close_session RPC — the session is still referenced.
+        cx.update(|cx| connection.clone().close_session(&session_id, cx))
+            .await
+            .expect("first close failed");
+
+        assert_eq!(
+            close_count.load(Ordering::SeqCst),
+            0,
+            "ACP close_session should not be sent while ref_count > 0"
+        );
+        assert!(
+            connection.sessions.borrow().contains_key(&session_id),
+            "session should still be tracked after first close"
+        );
+
+        // The second close drops ref_count to 0 — now the ACP RPC must be sent.
+        cx.update(|cx| connection.clone().close_session(&session_id, cx))
+            .await
+            .expect("second close failed");
+        cx.run_until_parked();
+
+        assert_eq!(
+            close_count.load(Ordering::SeqCst),
+            1,
+            "ACP close_session should be sent exactly once when ref_count reaches 0"
+        );
+        assert!(
+            !connection.sessions.borrow().contains_key(&session_id),
+            "session should be removed after final close"
+        );
+    }
 }
 
 fn mcp_servers_for_project(project: &Entity<Project>, cx: &App) -> Vec<acp::McpServer> {