From 2ca94a6032cb043c06174ebd84c3b972888b775d Mon Sep 17 00:00:00 2001 From: Ben Brandt Date: Wed, 22 Apr 2026 10:29:34 +0200 Subject: [PATCH] acp: Register ACP sessions before load replay (#54431) Insert sessions before awaiting `session/load` so replayed `session/update` notifications can find the thread. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - acp: Fix for some replay events getting dropped when loading a previous session. --- crates/agent_servers/src/acp.rs | 426 ++++++++++++++++++++++++++++++-- 1 file changed, 403 insertions(+), 23 deletions(-) diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index 62e3a526d00e6358237b8aaec38e523252fb273f..aba5a1b55c566365e7637fbc58d414e9c2825eba 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -461,13 +461,12 @@ impl AcpConnection { + 'static, cx: &mut App, ) -> Task>> { - if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) { - session.ref_count += 1; - if let Some(thread) = session.thread.upgrade() { - return Task::ready(Ok(thread)); - } - } - + // Check `pending_sessions` before `sessions` because the session is now + // inserted into `sessions` before the load RPC completes (so that + // notifications dispatched during history replay can find the thread). + // Concurrent loads should still wait for the in-flight task so that + // ref-counting happens in one place and the caller sees a fully loaded + // session. 
if let Some(pending) = self.pending_sessions.borrow_mut().get_mut(&session_id) { pending.ref_count += 1; let task = pending.task.clone(); @@ -476,6 +475,13 @@ impl AcpConnection { .spawn(async move { task.await.map_err(|err| anyhow!(err)) }); } + if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) { + session.ref_count += 1; + if let Some(thread) = session.thread.upgrade() { + return Task::ready(Ok(thread)); + } + } + // TODO: remove this once ACP supports multiple working directories let Some(cwd) = work_dirs.ordered_paths().next().cloned() else { return Task::ready(Err(anyhow!("Working directory cannot be empty"))); @@ -503,10 +509,27 @@ impl AcpConnection { ) }); + // Register the session before awaiting the RPC so that any + // `session/update` notifications that arrive during the call + // (e.g. history replay during `session/load`) can find the thread. + // Modes/models/config are filled in once the response arrives. + this.sessions.borrow_mut().insert( + session_id.clone(), + AcpSession { + thread: thread.downgrade(), + suppress_abort_err: false, + session_modes: None, + models: None, + config_options: None, + ref_count: 1, + }, + ); + let response = match rpc_call(this.connection.clone(), session_id.clone(), cwd).await { Ok(response) => response, Err(err) => { + this.sessions.borrow_mut().remove(&session_id); this.pending_sessions.borrow_mut().remove(&session_id); return Err(Arc::new(err)); } @@ -525,17 +548,23 @@ impl AcpConnection { .remove(&session_id) .map_or(1, |pending| pending.ref_count); - this.sessions.borrow_mut().insert( - session_id, - AcpSession { - thread: thread.downgrade(), - suppress_abort_err: false, - session_modes: modes, - models, - config_options: config_options.map(ConfigOptions::new), - ref_count, - }, - ); + // If `close_session` ran to completion while the load RPC was in + // flight, it will have removed both the pending entry and the + // sessions entry (and dispatched the ACP close RPC). 
In that case + // the thread has no live session to attach to, so fail the load + // instead of handing back an orphaned thread. + { + let mut sessions = this.sessions.borrow_mut(); + let Some(session) = sessions.get_mut(&session_id) else { + return Err(Arc::new(anyhow!( + "session was closed before load completed" + ))); + }; + session.session_modes = modes; + session.models = models; + session.config_options = config_options.map(ConfigOptions::new); + session.ref_count = ref_count; + } Ok(thread) } @@ -983,12 +1012,43 @@ impl AgentConnection for AcpConnection { )))); } + // If a load is still in flight, decrement its ref count. The pending + // entry is the source of truth for how many handles exist during a + // load, so we must tick it down here as well as the `sessions` entry + // that was pre-registered to receive history-replay notifications. + // Only once the pending ref count hits zero do we actually close the + // session; the load task will observe the missing sessions entry and + // fail with "session was closed before load completed". 
+ let pending_ref_count = { + let mut pending_sessions = self.pending_sessions.borrow_mut(); + pending_sessions.get_mut(session_id).map(|pending| { + pending.ref_count = pending.ref_count.saturating_sub(1); + pending.ref_count + }) + }; + match pending_ref_count { + Some(0) => { + self.pending_sessions.borrow_mut().remove(session_id); + self.sessions.borrow_mut().remove(session_id); + + let conn = self.connection.clone(); + let session_id = session_id.clone(); + return cx.foreground_executor().spawn(async move { + conn.close_session(acp::CloseSessionRequest::new(session_id)) + .await?; + Ok(()) + }); + } + Some(_) => return Task::ready(Ok(())), + None => {} + } + let mut sessions = self.sessions.borrow_mut(); let Some(session) = sessions.get_mut(session_id) else { return Task::ready(Ok(())); }; - session.ref_count -= 1; + session.ref_count = session.ref_count.saturating_sub(1); if session.ref_count > 0 { return Task::ready(Ok(())); } @@ -1789,6 +1849,12 @@ mod tests { struct FakeAcpAgent { load_session_count: Arc<AtomicUsize>, close_session_count: Arc<AtomicUsize>, + load_session_updates: Rc<RefCell<Vec<acp::SessionUpdate>>>, + /// When `Some`, `load_session` will await a message on this receiver + /// before returning its response, allowing tests to interleave other + /// work (e.g. `close_session`) with an in-flight load. + load_session_gate: Rc<RefCell<Option<smol::channel::Receiver<()>>>>, + client: Rc<RefCell<Option<Rc<acp::AgentSideConnection>>>>, } #[async_trait::async_trait(?Send)] @@ -1833,9 +1899,35 @@ mod tests { async fn load_session( &self, - _: acp::LoadSessionRequest, + args: acp::LoadSessionRequest, ) -> acp::Result<acp::LoadSessionResponse> { self.load_session_count.fetch_add(1, Ordering::SeqCst); + + // Simulate spec-compliant history replay: send notifications to the + // client before responding to the load request.
+ let updates = std::mem::take(&mut *self.load_session_updates.borrow_mut()); + if !updates.is_empty() { + let client = self + .client + .borrow() + .clone() + .expect("client should be set before load_session is called"); + for update in updates { + use acp::Client as _; + client + .session_notification(acp::SessionNotification::new( + args.session_id.clone(), + update, + )) + .await?; + } + } + + let gate = self.load_session_gate.borrow_mut().take(); + if let Some(gate) = gate { + gate.recv().await.ok(); + } + Ok(acp::LoadSessionResponse::new()) } @@ -1855,6 +1947,8 @@ mod tests { Entity, Arc<AtomicUsize>, Arc<AtomicUsize>, + Rc<RefCell<Vec<acp::SessionUpdate>>>, + Rc<RefCell<Option<smol::channel::Receiver<()>>>>, Task>, ) { cx.update(|cx| { @@ -1868,6 +1962,12 @@ mod tests { let load_count = Arc::new(AtomicUsize::new(0)); let close_count = Arc::new(AtomicUsize::new(0)); + let load_session_updates: Rc<RefCell<Vec<acp::SessionUpdate>>> = + Rc::new(RefCell::new(Vec::new())); + let load_session_gate: Rc<RefCell<Option<smol::channel::Receiver<()>>>> = + Rc::new(RefCell::new(None)); + let agent_client: Rc<RefCell<Option<Rc<acp::AgentSideConnection>>>> = + Rc::new(RefCell::new(None)); let (c2a_writer, c2a_reader) = async_pipe::pipe(); let (a2c_writer, a2c_reader) = async_pipe::pipe(); @@ -1896,15 +1996,19 @@ mod tests { let fake_agent = FakeAcpAgent { load_session_count: load_count.clone(), close_session_count: close_count.clone(), + load_session_updates: load_session_updates.clone(), + load_session_gate: load_session_gate.clone(), + client: agent_client.clone(), }; - let (_, agent_io_task) = + let (agent_conn, agent_io_task) = acp::AgentSideConnection::new(fake_agent, a2c_writer, c2a_reader, { let foreground = foreground.clone(); move |fut| { foreground.spawn(fut).detach(); } }); + *agent_client.borrow_mut() = Some(Rc::new(agent_conn)); let client_io_task = cx.background_spawn(client_io_task); let agent_io_task = cx.background_spawn(agent_io_task); @@ -1940,14 +2044,23 @@ mod tests { project, load_count, close_count, + load_session_updates, + load_session_gate, keep_agent_alive, ) } #[gpui::test] async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut gpui::TestAppContext)
{ - let (connection, project, load_count, close_count, _keep_agent_alive) = - connect_fake_agent(cx).await; + let ( + connection, + project, + load_count, + close_count, + _load_session_updates, + _load_session_gate, + _keep_agent_alive, + ) = connect_fake_agent(cx).await; let session_id = acp::SessionId::new("session-1"); let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]); @@ -2020,6 +2133,273 @@ mod tests { "session should be removed after final close" ); } + + // Regression test: per the ACP spec, an agent replays the entire conversation + // history as `session/update` notifications *before* responding to the + // `session/load` request. These notifications must be applied to the + // reconstructed thread, not dropped because the session hasn't been + // registered yet. + #[gpui::test] + async fn test_load_session_replays_notifications_sent_before_response( + cx: &mut gpui::TestAppContext, + ) { + let ( + connection, + project, + _load_count, + _close_count, + load_session_updates, + _load_session_gate, + _keep_agent_alive, + ) = connect_fake_agent(cx).await; + + // Queue up some history updates that the fake agent will stream to + // the client during the `load_session` call, before responding. 
+ *load_session_updates.borrow_mut() = vec![ + acp::SessionUpdate::UserMessageChunk(acp::ContentChunk::new(acp::ContentBlock::Text( + acp::TextContent::new(String::from("hello agent")), + ))), + acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(acp::ContentBlock::Text( + acp::TextContent::new(String::from("hi user")), + ))), + ]; + + let session_id = acp::SessionId::new("session-replay"); + let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]); + + let thread = cx + .update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + work_dirs, + None, + cx, + ) + }) + .await + .expect("load_session failed"); + cx.run_until_parked(); + + let entries = thread.read_with(cx, |thread, _| { + thread + .entries() + .iter() + .map(|entry| match entry { + acp_thread::AgentThreadEntry::UserMessage(_) => "user", + acp_thread::AgentThreadEntry::AssistantMessage(_) => "assistant", + acp_thread::AgentThreadEntry::ToolCall(_) => "tool_call", + acp_thread::AgentThreadEntry::CompletedPlan(_) => "plan", + }) + .collect::<Vec<_>>() + }); + + assert_eq!( + entries, + vec!["user", "assistant"], + "replayed notifications should be applied to the thread" + ); + } + + // Regression test: if `close_session` is issued while a `load_session` + // RPC is still in flight, the close must take effect cleanly — the load + // must fail with a recognizable error (not return an orphaned thread), + // no entry must remain in `sessions` or `pending_sessions`, and the ACP + // `close_session` RPC must be dispatched. + #[gpui::test] + async fn test_close_session_during_in_flight_load(cx: &mut gpui::TestAppContext) { + let ( + connection, + project, + load_count, + close_count, + _load_session_updates, + load_session_gate, + _keep_agent_alive, + ) = connect_fake_agent(cx).await; + + // Install a gate so the fake agent's `load_session` handler parks + // before sending its response. We'll close the session while the + // load is parked.
+ let (gate_tx, gate_rx) = smol::channel::bounded::<()>(1); + *load_session_gate.borrow_mut() = Some(gate_rx); + + let session_id = acp::SessionId::new("session-close-during-load"); + let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]); + + let load_task = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + work_dirs, + None, + cx, + ) + }); + + // Let the load RPC reach the agent and park on the gate. + cx.run_until_parked(); + assert_eq!( + load_count.load(Ordering::SeqCst), + 1, + "load_session RPC should have been dispatched" + ); + assert!( + connection + .pending_sessions + .borrow() + .contains_key(&session_id), + "pending_sessions entry should exist while load is in flight" + ); + assert!( + connection.sessions.borrow().contains_key(&session_id), + "sessions entry should be pre-registered to receive replay notifications" + ); + + // Close the session while the load is still parked. This should take + // the pending path and dispatch the ACP close RPC. + let close_task = cx.update(|cx| connection.clone().close_session(&session_id, cx)); + + // Release the gate so the load RPC can finally respond. 
+ gate_tx.send(()).await.expect("gate send failed"); + drop(gate_tx); + + let load_result = load_task.await; + close_task.await.expect("close failed"); + cx.run_until_parked(); + + let err = load_result.expect_err("load should fail after close-during-load"); + assert!( + err.to_string() + .contains("session was closed before load completed"), + "expected close-during-load error, got: {err}" + ); + + assert_eq!( + close_count.load(Ordering::SeqCst), + 1, + "ACP close_session should be sent exactly once" + ); + assert!( + !connection.sessions.borrow().contains_key(&session_id), + "sessions entry should be removed after close-during-load" + ); + assert!( + !connection + .pending_sessions + .borrow() + .contains_key(&session_id), + "pending_sessions entry should be removed after close-during-load" + ); + } + + // Regression test: when two concurrent `load_session` calls share a pending + // task and one of them issues `close_session` before the load RPC + // resolves, the remaining load must still succeed and the session must + // stay live. If `close_session` incorrectly short-circuits via the + // `sessions` path (removing the entry while a load is still in flight), + // the pending task will fail and both concurrent loaders will lose + // their handle. + #[gpui::test] + async fn test_close_during_load_preserves_other_concurrent_loader( + cx: &mut gpui::TestAppContext, + ) { + let ( + connection, + project, + load_count, + close_count, + _load_session_updates, + load_session_gate, + _keep_agent_alive, + ) = connect_fake_agent(cx).await; + + let (gate_tx, gate_rx) = smol::channel::bounded::<()>(1); + *load_session_gate.borrow_mut() = Some(gate_rx); + + let session_id = acp::SessionId::new("session-concurrent-close"); + let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]); + + // Kick off two concurrent loads; the second must join the first's pending + // task rather than issuing a second RPC. 
+ let first_load = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + work_dirs.clone(), + None, + cx, + ) + }); + let second_load = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + work_dirs.clone(), + None, + cx, + ) + }); + + cx.run_until_parked(); + assert_eq!( + load_count.load(Ordering::SeqCst), + 1, + "load_session RPC should only be dispatched once for concurrent loads" + ); + + // Close one of the two handles while the shared load is still parked. + // Because a second loader still holds a pending ref, this should be a + // no-op on the wire. + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .expect("close during load failed"); + assert_eq!( + close_count.load(Ordering::SeqCst), + 0, + "close_session RPC must not be dispatched while another load handle remains" + ); + + // Release the gate so the load RPC can finally respond. + gate_tx.send(()).await.expect("gate send failed"); + drop(gate_tx); + + let first_thread = first_load.await.expect("first load should still succeed"); + let second_thread = second_load.await.expect("second load should still succeed"); + cx.run_until_parked(); + + assert_eq!( + first_thread.entity_id(), + second_thread.entity_id(), + "concurrent loads should share one AcpThread" + ); + assert!( + connection.sessions.borrow().contains_key(&session_id), + "session must remain tracked while a load handle is still outstanding" + ); + assert!( + !connection + .pending_sessions + .borrow() + .contains_key(&session_id), + "pending_sessions entry should be cleared once the load resolves" + ); + + // Final close drops ref_count to 0 and dispatches the ACP close RPC. 
+ cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .expect("final close failed"); + cx.run_until_parked(); + assert_eq!( + close_count.load(Ordering::SeqCst), + 1, + "close_session RPC should fire exactly once when the last handle is released" + ); + assert!( + !connection.sessions.borrow().contains_key(&session_id), + "session should be removed after final close" + ); + } } fn mcp_servers_for_project(project: &Entity, cx: &App) -> Vec {