acp: Register ACP sessions before load replay (#54431)

Ben Brandt created

Insert sessions before awaiting `session/load` so replayed
`session/update` notifications can find the thread.

Self-Review Checklist:

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Release Notes:

- acp: Fixed replay events being dropped when loading a previous session.

Change summary

crates/agent_servers/src/acp.rs | 426 +++++++++++++++++++++++++++++++++-
1 file changed, 403 insertions(+), 23 deletions(-)

Detailed changes

crates/agent_servers/src/acp.rs 🔗

@@ -461,13 +461,12 @@ impl AcpConnection {
         + 'static,
         cx: &mut App,
     ) -> Task<Result<Entity<AcpThread>>> {
-        if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
-            session.ref_count += 1;
-            if let Some(thread) = session.thread.upgrade() {
-                return Task::ready(Ok(thread));
-            }
-        }
-
+        // Check `pending_sessions` before `sessions` because the session is now
+        // inserted into `sessions` before the load RPC completes (so that
+        // notifications dispatched during history replay can find the thread).
+        // Concurrent loads should still wait for the in-flight task so that
+        // ref-counting happens in one place and the caller sees a fully loaded
+        // session.
         if let Some(pending) = self.pending_sessions.borrow_mut().get_mut(&session_id) {
             pending.ref_count += 1;
             let task = pending.task.clone();
@@ -476,6 +475,13 @@ impl AcpConnection {
                 .spawn(async move { task.await.map_err(|err| anyhow!(err)) });
         }
 
+        if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
+            session.ref_count += 1;
+            if let Some(thread) = session.thread.upgrade() {
+                return Task::ready(Ok(thread));
+            }
+        }
+
         // TODO: remove this once ACP supports multiple working directories
         let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
             return Task::ready(Err(anyhow!("Working directory cannot be empty")));
@@ -503,10 +509,27 @@ impl AcpConnection {
                         )
                     });
 
+                    // Register the session before awaiting the RPC so that any
+                    // `session/update` notifications that arrive during the call
+                    // (e.g. history replay during `session/load`) can find the thread.
+                    // Modes/models/config are filled in once the response arrives.
+                    this.sessions.borrow_mut().insert(
+                        session_id.clone(),
+                        AcpSession {
+                            thread: thread.downgrade(),
+                            suppress_abort_err: false,
+                            session_modes: None,
+                            models: None,
+                            config_options: None,
+                            ref_count: 1,
+                        },
+                    );
+
                     let response =
                         match rpc_call(this.connection.clone(), session_id.clone(), cwd).await {
                             Ok(response) => response,
                             Err(err) => {
+                                this.sessions.borrow_mut().remove(&session_id);
                                 this.pending_sessions.borrow_mut().remove(&session_id);
                                 return Err(Arc::new(err));
                             }
@@ -525,17 +548,23 @@ impl AcpConnection {
                         .remove(&session_id)
                         .map_or(1, |pending| pending.ref_count);
 
-                    this.sessions.borrow_mut().insert(
-                        session_id,
-                        AcpSession {
-                            thread: thread.downgrade(),
-                            suppress_abort_err: false,
-                            session_modes: modes,
-                            models,
-                            config_options: config_options.map(ConfigOptions::new),
-                            ref_count,
-                        },
-                    );
+                    // If `close_session` ran to completion while the load RPC was in
+                    // flight, it will have removed both the pending entry and the
+                    // sessions entry (and dispatched the ACP close RPC). In that case
+                    // the thread has no live session to attach to, so fail the load
+                    // instead of handing back an orphaned thread.
+                    {
+                        let mut sessions = this.sessions.borrow_mut();
+                        let Some(session) = sessions.get_mut(&session_id) else {
+                            return Err(Arc::new(anyhow!(
+                                "session was closed before load completed"
+                            )));
+                        };
+                        session.session_modes = modes;
+                        session.models = models;
+                        session.config_options = config_options.map(ConfigOptions::new);
+                        session.ref_count = ref_count;
+                    }
 
                     Ok(thread)
                 }
@@ -983,12 +1012,43 @@ impl AgentConnection for AcpConnection {
             ))));
         }
 
+        // If a load is still in flight, decrement its ref count. The pending
+        // entry is the source of truth for how many handles exist during a
+        // load, so we must tick it down here as well as the `sessions` entry
+        // that was pre-registered to receive history-replay notifications.
+        // Only once the pending ref count hits zero do we actually close the
+        // session; the load task will observe the missing sessions entry and
+        // fail with "session was closed before load completed".
+        let pending_ref_count = {
+            let mut pending_sessions = self.pending_sessions.borrow_mut();
+            pending_sessions.get_mut(session_id).map(|pending| {
+                pending.ref_count = pending.ref_count.saturating_sub(1);
+                pending.ref_count
+            })
+        };
+        match pending_ref_count {
+            Some(0) => {
+                self.pending_sessions.borrow_mut().remove(session_id);
+                self.sessions.borrow_mut().remove(session_id);
+
+                let conn = self.connection.clone();
+                let session_id = session_id.clone();
+                return cx.foreground_executor().spawn(async move {
+                    conn.close_session(acp::CloseSessionRequest::new(session_id))
+                        .await?;
+                    Ok(())
+                });
+            }
+            Some(_) => return Task::ready(Ok(())),
+            None => {}
+        }
+
         let mut sessions = self.sessions.borrow_mut();
         let Some(session) = sessions.get_mut(session_id) else {
             return Task::ready(Ok(()));
         };
 
-        session.ref_count -= 1;
+        session.ref_count = session.ref_count.saturating_sub(1);
         if session.ref_count > 0 {
             return Task::ready(Ok(()));
         }
@@ -1789,6 +1849,12 @@ mod tests {
     struct FakeAcpAgent {
         load_session_count: Arc<AtomicUsize>,
         close_session_count: Arc<AtomicUsize>,
+        load_session_updates: Rc<RefCell<Vec<acp::SessionUpdate>>>,
+        /// When `Some`, `load_session` will await a message on this receiver
+        /// before returning its response, allowing tests to interleave other
+        /// work (e.g. `close_session`) with an in-flight load.
+        load_session_gate: Rc<RefCell<Option<smol::channel::Receiver<()>>>>,
+        client: Rc<RefCell<Option<Rc<acp::AgentSideConnection>>>>,
     }
 
     #[async_trait::async_trait(?Send)]
@@ -1833,9 +1899,35 @@ mod tests {
 
         async fn load_session(
             &self,
-            _: acp::LoadSessionRequest,
+            args: acp::LoadSessionRequest,
         ) -> acp::Result<acp::LoadSessionResponse> {
             self.load_session_count.fetch_add(1, Ordering::SeqCst);
+
+            // Simulate spec-compliant history replay: send notifications to the
+            // client before responding to the load request.
+            let updates = std::mem::take(&mut *self.load_session_updates.borrow_mut());
+            if !updates.is_empty() {
+                let client = self
+                    .client
+                    .borrow()
+                    .clone()
+                    .expect("client should be set before load_session is called");
+                for update in updates {
+                    use acp::Client as _;
+                    client
+                        .session_notification(acp::SessionNotification::new(
+                            args.session_id.clone(),
+                            update,
+                        ))
+                        .await?;
+                }
+            }
+
+            let gate = self.load_session_gate.borrow_mut().take();
+            if let Some(gate) = gate {
+                gate.recv().await.ok();
+            }
+
             Ok(acp::LoadSessionResponse::new())
         }
 
@@ -1855,6 +1947,8 @@ mod tests {
         Entity<project::Project>,
         Arc<AtomicUsize>,
         Arc<AtomicUsize>,
+        Rc<RefCell<Vec<acp::SessionUpdate>>>,
+        Rc<RefCell<Option<smol::channel::Receiver<()>>>>,
         Task<anyhow::Result<()>>,
     ) {
         cx.update(|cx| {
@@ -1868,6 +1962,12 @@ mod tests {
 
         let load_count = Arc::new(AtomicUsize::new(0));
         let close_count = Arc::new(AtomicUsize::new(0));
+        let load_session_updates: Rc<RefCell<Vec<acp::SessionUpdate>>> =
+            Rc::new(RefCell::new(Vec::new()));
+        let load_session_gate: Rc<RefCell<Option<smol::channel::Receiver<()>>>> =
+            Rc::new(RefCell::new(None));
+        let agent_client: Rc<RefCell<Option<Rc<acp::AgentSideConnection>>>> =
+            Rc::new(RefCell::new(None));
 
         let (c2a_writer, c2a_reader) = async_pipe::pipe();
         let (a2c_writer, a2c_reader) = async_pipe::pipe();
@@ -1896,15 +1996,19 @@ mod tests {
         let fake_agent = FakeAcpAgent {
             load_session_count: load_count.clone(),
             close_session_count: close_count.clone(),
+            load_session_updates: load_session_updates.clone(),
+            load_session_gate: load_session_gate.clone(),
+            client: agent_client.clone(),
         };
 
-        let (_, agent_io_task) =
+        let (agent_conn, agent_io_task) =
             acp::AgentSideConnection::new(fake_agent, a2c_writer, c2a_reader, {
                 let foreground = foreground.clone();
                 move |fut| {
                     foreground.spawn(fut).detach();
                 }
             });
+        *agent_client.borrow_mut() = Some(Rc::new(agent_conn));
 
         let client_io_task = cx.background_spawn(client_io_task);
         let agent_io_task = cx.background_spawn(agent_io_task);
@@ -1940,14 +2044,23 @@ mod tests {
             project,
             load_count,
             close_count,
+            load_session_updates,
+            load_session_gate,
             keep_agent_alive,
         )
     }
 
     #[gpui::test]
     async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut gpui::TestAppContext) {
-        let (connection, project, load_count, close_count, _keep_agent_alive) =
-            connect_fake_agent(cx).await;
+        let (
+            connection,
+            project,
+            load_count,
+            close_count,
+            _load_session_updates,
+            _load_session_gate,
+            _keep_agent_alive,
+        ) = connect_fake_agent(cx).await;
 
         let session_id = acp::SessionId::new("session-1");
         let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
@@ -2020,6 +2133,273 @@ mod tests {
             "session should be removed after final close"
         );
     }
+
+    // Regression test: per the ACP spec, an agent replays the entire conversation
+    // history as `session/update` notifications *before* responding to the
+    // `session/load` request. These notifications must be applied to the
+    // reconstructed thread, not dropped because the session hasn't been
+    // registered yet.
+    #[gpui::test]
+    async fn test_load_session_replays_notifications_sent_before_response(
+        cx: &mut gpui::TestAppContext,
+    ) {
+        let (
+            connection,
+            project,
+            _load_count,
+            _close_count,
+            load_session_updates,
+            _load_session_gate,
+            _keep_agent_alive,
+        ) = connect_fake_agent(cx).await;
+
+        // Queue up some history updates that the fake agent will stream to
+        // the client during the `load_session` call, before responding.
+        *load_session_updates.borrow_mut() = vec![
+            acp::SessionUpdate::UserMessageChunk(acp::ContentChunk::new(acp::ContentBlock::Text(
+                acp::TextContent::new(String::from("hello agent")),
+            ))),
+            acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(acp::ContentBlock::Text(
+                acp::TextContent::new(String::from("hi user")),
+            ))),
+        ];
+
+        let session_id = acp::SessionId::new("session-replay");
+        let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
+
+        let thread = cx
+            .update(|cx| {
+                connection.clone().load_session(
+                    session_id.clone(),
+                    project.clone(),
+                    work_dirs,
+                    None,
+                    cx,
+                )
+            })
+            .await
+            .expect("load_session failed");
+        cx.run_until_parked();
+
+        let entries = thread.read_with(cx, |thread, _| {
+            thread
+                .entries()
+                .iter()
+                .map(|entry| match entry {
+                    acp_thread::AgentThreadEntry::UserMessage(_) => "user",
+                    acp_thread::AgentThreadEntry::AssistantMessage(_) => "assistant",
+                    acp_thread::AgentThreadEntry::ToolCall(_) => "tool_call",
+                    acp_thread::AgentThreadEntry::CompletedPlan(_) => "plan",
+                })
+                .collect::<Vec<_>>()
+        });
+
+        assert_eq!(
+            entries,
+            vec!["user", "assistant"],
+            "replayed notifications should be applied to the thread"
+        );
+    }
+
+    // Regression test: if `close_session` is issued while a `load_session`
+    // RPC is still in flight, the close must take effect cleanly — the load
+    // must fail with a recognizable error (not return an orphaned thread),
+    // no entry must remain in `sessions` or `pending_sessions`, and the ACP
+    // `close_session` RPC must be dispatched.
+    #[gpui::test]
+    async fn test_close_session_during_in_flight_load(cx: &mut gpui::TestAppContext) {
+        let (
+            connection,
+            project,
+            load_count,
+            close_count,
+            _load_session_updates,
+            load_session_gate,
+            _keep_agent_alive,
+        ) = connect_fake_agent(cx).await;
+
+        // Install a gate so the fake agent's `load_session` handler parks
+        // before sending its response. We'll close the session while the
+        // load is parked.
+        let (gate_tx, gate_rx) = smol::channel::bounded::<()>(1);
+        *load_session_gate.borrow_mut() = Some(gate_rx);
+
+        let session_id = acp::SessionId::new("session-close-during-load");
+        let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
+
+        let load_task = cx.update(|cx| {
+            connection.clone().load_session(
+                session_id.clone(),
+                project.clone(),
+                work_dirs,
+                None,
+                cx,
+            )
+        });
+
+        // Let the load RPC reach the agent and park on the gate.
+        cx.run_until_parked();
+        assert_eq!(
+            load_count.load(Ordering::SeqCst),
+            1,
+            "load_session RPC should have been dispatched"
+        );
+        assert!(
+            connection
+                .pending_sessions
+                .borrow()
+                .contains_key(&session_id),
+            "pending_sessions entry should exist while load is in flight"
+        );
+        assert!(
+            connection.sessions.borrow().contains_key(&session_id),
+            "sessions entry should be pre-registered to receive replay notifications"
+        );
+
+        // Close the session while the load is still parked. This should take
+        // the pending path and dispatch the ACP close RPC.
+        let close_task = cx.update(|cx| connection.clone().close_session(&session_id, cx));
+
+        // Release the gate so the load RPC can finally respond.
+        gate_tx.send(()).await.expect("gate send failed");
+        drop(gate_tx);
+
+        let load_result = load_task.await;
+        close_task.await.expect("close failed");
+        cx.run_until_parked();
+
+        let err = load_result.expect_err("load should fail after close-during-load");
+        assert!(
+            err.to_string()
+                .contains("session was closed before load completed"),
+            "expected close-during-load error, got: {err}"
+        );
+
+        assert_eq!(
+            close_count.load(Ordering::SeqCst),
+            1,
+            "ACP close_session should be sent exactly once"
+        );
+        assert!(
+            !connection.sessions.borrow().contains_key(&session_id),
+            "sessions entry should be removed after close-during-load"
+        );
+        assert!(
+            !connection
+                .pending_sessions
+                .borrow()
+                .contains_key(&session_id),
+            "pending_sessions entry should be removed after close-during-load"
+        );
+    }
+
+    // Regression test: when two concurrent `load_session` calls share a pending
+    // task and one of them issues `close_session` before the load RPC
+    // resolves, the remaining load must still succeed and the session must
+    // stay live. If `close_session` incorrectly short-circuits via the
+    // `sessions` path (removing the entry while a load is still in flight),
+    // the pending task will fail and both concurrent loaders will lose
+    // their handle.
+    #[gpui::test]
+    async fn test_close_during_load_preserves_other_concurrent_loader(
+        cx: &mut gpui::TestAppContext,
+    ) {
+        let (
+            connection,
+            project,
+            load_count,
+            close_count,
+            _load_session_updates,
+            load_session_gate,
+            _keep_agent_alive,
+        ) = connect_fake_agent(cx).await;
+
+        let (gate_tx, gate_rx) = smol::channel::bounded::<()>(1);
+        *load_session_gate.borrow_mut() = Some(gate_rx);
+
+        let session_id = acp::SessionId::new("session-concurrent-close");
+        let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
+
+        // Kick off two concurrent loads; the second must join the first's pending
+        // task rather than issuing a second RPC.
+        let first_load = cx.update(|cx| {
+            connection.clone().load_session(
+                session_id.clone(),
+                project.clone(),
+                work_dirs.clone(),
+                None,
+                cx,
+            )
+        });
+        let second_load = cx.update(|cx| {
+            connection.clone().load_session(
+                session_id.clone(),
+                project.clone(),
+                work_dirs.clone(),
+                None,
+                cx,
+            )
+        });
+
+        cx.run_until_parked();
+        assert_eq!(
+            load_count.load(Ordering::SeqCst),
+            1,
+            "load_session RPC should only be dispatched once for concurrent loads"
+        );
+
+        // Close one of the two handles while the shared load is still parked.
+        // Because a second loader still holds a pending ref, this should be a
+        // no-op on the wire.
+        cx.update(|cx| connection.clone().close_session(&session_id, cx))
+            .await
+            .expect("close during load failed");
+        assert_eq!(
+            close_count.load(Ordering::SeqCst),
+            0,
+            "close_session RPC must not be dispatched while another load handle remains"
+        );
+
+        // Release the gate so the load RPC can finally respond.
+        gate_tx.send(()).await.expect("gate send failed");
+        drop(gate_tx);
+
+        let first_thread = first_load.await.expect("first load should still succeed");
+        let second_thread = second_load.await.expect("second load should still succeed");
+        cx.run_until_parked();
+
+        assert_eq!(
+            first_thread.entity_id(),
+            second_thread.entity_id(),
+            "concurrent loads should share one AcpThread"
+        );
+        assert!(
+            connection.sessions.borrow().contains_key(&session_id),
+            "session must remain tracked while a load handle is still outstanding"
+        );
+        assert!(
+            !connection
+                .pending_sessions
+                .borrow()
+                .contains_key(&session_id),
+            "pending_sessions entry should be cleared once the load resolves"
+        );
+
+        // Final close drops ref_count to 0 and dispatches the ACP close RPC.
+        cx.update(|cx| connection.clone().close_session(&session_id, cx))
+            .await
+            .expect("final close failed");
+        cx.run_until_parked();
+        assert_eq!(
+            close_count.load(Ordering::SeqCst),
+            1,
+            "close_session RPC should fire exactly once when the last handle is released"
+        );
+        assert!(
+            !connection.sessions.borrow().contains_key(&session_id),
+            "session should be removed after final close"
+        );
+    }
 }
 
 fn mcp_servers_for_project(project: &Entity<Project>, cx: &App) -> Vec<acp::McpServer> {