@@ -461,13 +461,12 @@ impl AcpConnection {
+ 'static,
cx: &mut App,
) -> Task<Result<Entity<AcpThread>>> {
- if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
- session.ref_count += 1;
- if let Some(thread) = session.thread.upgrade() {
- return Task::ready(Ok(thread));
- }
- }
-
+ // Check `pending_sessions` before `sessions` because the session is now
+ // inserted into `sessions` before the load RPC completes (so that
+ // notifications dispatched during history replay can find the thread).
+ // Concurrent loads should still wait for the in-flight task so that
+ // ref-counting happens in one place and the caller sees a fully loaded
+ // session.
if let Some(pending) = self.pending_sessions.borrow_mut().get_mut(&session_id) {
pending.ref_count += 1;
let task = pending.task.clone();
@@ -476,6 +475,13 @@ impl AcpConnection {
.spawn(async move { task.await.map_err(|err| anyhow!(err)) });
}
+ if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
+ session.ref_count += 1;
+ if let Some(thread) = session.thread.upgrade() {
+ return Task::ready(Ok(thread));
+ }
+ }
+
// TODO: remove this once ACP supports multiple working directories
let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
return Task::ready(Err(anyhow!("Working directory cannot be empty")));
@@ -503,10 +509,27 @@ impl AcpConnection {
)
});
+ // Register the session before awaiting the RPC so that any
+ // `session/update` notifications that arrive during the call
+ // (e.g. history replay during `session/load`) can find the thread.
+ // Modes/models/config are filled in once the response arrives.
+ this.sessions.borrow_mut().insert(
+ session_id.clone(),
+ AcpSession {
+ thread: thread.downgrade(),
+ suppress_abort_err: false,
+ session_modes: None,
+ models: None,
+ config_options: None,
+ ref_count: 1,
+ },
+ );
+
let response =
match rpc_call(this.connection.clone(), session_id.clone(), cwd).await {
Ok(response) => response,
Err(err) => {
+ this.sessions.borrow_mut().remove(&session_id);
this.pending_sessions.borrow_mut().remove(&session_id);
return Err(Arc::new(err));
}
@@ -525,17 +548,23 @@ impl AcpConnection {
.remove(&session_id)
.map_or(1, |pending| pending.ref_count);
- this.sessions.borrow_mut().insert(
- session_id,
- AcpSession {
- thread: thread.downgrade(),
- suppress_abort_err: false,
- session_modes: modes,
- models,
- config_options: config_options.map(ConfigOptions::new),
- ref_count,
- },
- );
+ // If `close_session` ran to completion while the load RPC was in
+ // flight, it will have removed both the pending entry and the
+ // sessions entry (and dispatched the ACP close RPC). In that case
+ // the thread has no live session to attach to, so fail the load
+ // instead of handing back an orphaned thread.
+ {
+ let mut sessions = this.sessions.borrow_mut();
+ let Some(session) = sessions.get_mut(&session_id) else {
+ return Err(Arc::new(anyhow!(
+ "session was closed before load completed"
+ )));
+ };
+ session.session_modes = modes;
+ session.models = models;
+ session.config_options = config_options.map(ConfigOptions::new);
+ session.ref_count = ref_count;
+ }
Ok(thread)
}
@@ -983,12 +1012,43 @@ impl AgentConnection for AcpConnection {
))));
}
+ // If a load is still in flight, decrement its ref count. The pending
+ // entry is the source of truth for how many handles exist during a
+ // load, so we must tick it down here as well as the `sessions` entry
+ // that was pre-registered to receive history-replay notifications.
+ // Only once the pending ref count hits zero do we actually close the
+ // session; the load task will observe the missing sessions entry and
+ // fail with "session was closed before load completed".
+ let pending_ref_count = {
+ let mut pending_sessions = self.pending_sessions.borrow_mut();
+ pending_sessions.get_mut(session_id).map(|pending| {
+ pending.ref_count = pending.ref_count.saturating_sub(1);
+ pending.ref_count
+ })
+ };
+ match pending_ref_count {
+ Some(0) => {
+ self.pending_sessions.borrow_mut().remove(session_id);
+ self.sessions.borrow_mut().remove(session_id);
+
+ let conn = self.connection.clone();
+ let session_id = session_id.clone();
+ return cx.foreground_executor().spawn(async move {
+ conn.close_session(acp::CloseSessionRequest::new(session_id))
+ .await?;
+ Ok(())
+ });
+ }
+ Some(_) => return Task::ready(Ok(())),
+ None => {}
+ }
+
let mut sessions = self.sessions.borrow_mut();
let Some(session) = sessions.get_mut(session_id) else {
return Task::ready(Ok(()));
};
- session.ref_count -= 1;
+ session.ref_count = session.ref_count.saturating_sub(1);
if session.ref_count > 0 {
return Task::ready(Ok(()));
}
@@ -1789,6 +1849,12 @@ mod tests {
struct FakeAcpAgent {
load_session_count: Arc<AtomicUsize>,
close_session_count: Arc<AtomicUsize>,
+ load_session_updates: Rc<RefCell<Vec<acp::SessionUpdate>>>,
+ /// When `Some`, `load_session` will await a message on this receiver
+ /// before returning its response, allowing tests to interleave other
+ /// work (e.g. `close_session`) with an in-flight load.
+ load_session_gate: Rc<RefCell<Option<smol::channel::Receiver<()>>>>,
+ client: Rc<RefCell<Option<Rc<acp::AgentSideConnection>>>>,
}
#[async_trait::async_trait(?Send)]
@@ -1833,9 +1899,35 @@ mod tests {
async fn load_session(
&self,
- _: acp::LoadSessionRequest,
+ args: acp::LoadSessionRequest,
) -> acp::Result<acp::LoadSessionResponse> {
self.load_session_count.fetch_add(1, Ordering::SeqCst);
+
+ // Simulate spec-compliant history replay: send notifications to the
+ // client before responding to the load request.
+ let updates = std::mem::take(&mut *self.load_session_updates.borrow_mut());
+ if !updates.is_empty() {
+ let client = self
+ .client
+ .borrow()
+ .clone()
+ .expect("client should be set before load_session is called");
+ for update in updates {
+ use acp::Client as _;
+ client
+ .session_notification(acp::SessionNotification::new(
+ args.session_id.clone(),
+ update,
+ ))
+ .await?;
+ }
+ }
+
+ let gate = self.load_session_gate.borrow_mut().take();
+ if let Some(gate) = gate {
+ gate.recv().await.ok();
+ }
+
Ok(acp::LoadSessionResponse::new())
}
@@ -1855,6 +1947,8 @@ mod tests {
Entity<project::Project>,
Arc<AtomicUsize>,
Arc<AtomicUsize>,
+ Rc<RefCell<Vec<acp::SessionUpdate>>>,
+ Rc<RefCell<Option<smol::channel::Receiver<()>>>>,
Task<anyhow::Result<()>>,
) {
cx.update(|cx| {
@@ -1868,6 +1962,12 @@ mod tests {
let load_count = Arc::new(AtomicUsize::new(0));
let close_count = Arc::new(AtomicUsize::new(0));
+ let load_session_updates: Rc<RefCell<Vec<acp::SessionUpdate>>> =
+ Rc::new(RefCell::new(Vec::new()));
+ let load_session_gate: Rc<RefCell<Option<smol::channel::Receiver<()>>>> =
+ Rc::new(RefCell::new(None));
+ let agent_client: Rc<RefCell<Option<Rc<acp::AgentSideConnection>>>> =
+ Rc::new(RefCell::new(None));
let (c2a_writer, c2a_reader) = async_pipe::pipe();
let (a2c_writer, a2c_reader) = async_pipe::pipe();
@@ -1896,15 +1996,19 @@ mod tests {
let fake_agent = FakeAcpAgent {
load_session_count: load_count.clone(),
close_session_count: close_count.clone(),
+ load_session_updates: load_session_updates.clone(),
+ load_session_gate: load_session_gate.clone(),
+ client: agent_client.clone(),
};
- let (_, agent_io_task) =
+ let (agent_conn, agent_io_task) =
acp::AgentSideConnection::new(fake_agent, a2c_writer, c2a_reader, {
let foreground = foreground.clone();
move |fut| {
foreground.spawn(fut).detach();
}
});
+ *agent_client.borrow_mut() = Some(Rc::new(agent_conn));
let client_io_task = cx.background_spawn(client_io_task);
let agent_io_task = cx.background_spawn(agent_io_task);
@@ -1940,14 +2044,23 @@ mod tests {
project,
load_count,
close_count,
+ load_session_updates,
+ load_session_gate,
keep_agent_alive,
)
}
#[gpui::test]
async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut gpui::TestAppContext) {
- let (connection, project, load_count, close_count, _keep_agent_alive) =
- connect_fake_agent(cx).await;
+ let (
+ connection,
+ project,
+ load_count,
+ close_count,
+ _load_session_updates,
+ _load_session_gate,
+ _keep_agent_alive,
+ ) = connect_fake_agent(cx).await;
let session_id = acp::SessionId::new("session-1");
let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
@@ -2020,6 +2133,273 @@ mod tests {
"session should be removed after final close"
);
}
+
+ // Regression test: per the ACP spec, an agent replays the entire conversation
+ // history as `session/update` notifications *before* responding to the
+ // `session/load` request. These notifications must be applied to the
+ // reconstructed thread, not dropped because the session hasn't been
+ // registered yet.
+ #[gpui::test]
+ async fn test_load_session_replays_notifications_sent_before_response(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ let (
+ connection,
+ project,
+ _load_count,
+ _close_count,
+ load_session_updates,
+ _load_session_gate,
+ _keep_agent_alive,
+ ) = connect_fake_agent(cx).await;
+
+ // Queue up some history updates that the fake agent will stream to
+ // the client during the `load_session` call, before responding.
+ *load_session_updates.borrow_mut() = vec![
+ acp::SessionUpdate::UserMessageChunk(acp::ContentChunk::new(acp::ContentBlock::Text(
+ acp::TextContent::new(String::from("hello agent")),
+ ))),
+ acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(acp::ContentBlock::Text(
+ acp::TextContent::new(String::from("hi user")),
+ ))),
+ ];
+
+ let session_id = acp::SessionId::new("session-replay");
+ let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
+
+ let thread = cx
+ .update(|cx| {
+ connection.clone().load_session(
+ session_id.clone(),
+ project.clone(),
+ work_dirs,
+ None,
+ cx,
+ )
+ })
+ .await
+ .expect("load_session failed");
+ cx.run_until_parked();
+
+ let entries = thread.read_with(cx, |thread, _| {
+ thread
+ .entries()
+ .iter()
+ .map(|entry| match entry {
+ acp_thread::AgentThreadEntry::UserMessage(_) => "user",
+ acp_thread::AgentThreadEntry::AssistantMessage(_) => "assistant",
+ acp_thread::AgentThreadEntry::ToolCall(_) => "tool_call",
+ acp_thread::AgentThreadEntry::CompletedPlan(_) => "plan",
+ })
+ .collect::<Vec<_>>()
+ });
+
+ assert_eq!(
+ entries,
+ vec!["user", "assistant"],
+ "replayed notifications should be applied to the thread"
+ );
+ }
+
+ // Regression test: if `close_session` is issued while a `load_session`
+ // RPC is still in flight, the close must take effect cleanly — the load
+ // must fail with a recognizable error (not return an orphaned thread),
+ // no entry must remain in `sessions` or `pending_sessions`, and the ACP
+ // `close_session` RPC must be dispatched.
+ #[gpui::test]
+ async fn test_close_session_during_in_flight_load(cx: &mut gpui::TestAppContext) {
+ let (
+ connection,
+ project,
+ load_count,
+ close_count,
+ _load_session_updates,
+ load_session_gate,
+ _keep_agent_alive,
+ ) = connect_fake_agent(cx).await;
+
+ // Install a gate so the fake agent's `load_session` handler parks
+ // before sending its response. We'll close the session while the
+ // load is parked.
+ let (gate_tx, gate_rx) = smol::channel::bounded::<()>(1);
+ *load_session_gate.borrow_mut() = Some(gate_rx);
+
+ let session_id = acp::SessionId::new("session-close-during-load");
+ let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
+
+ let load_task = cx.update(|cx| {
+ connection.clone().load_session(
+ session_id.clone(),
+ project.clone(),
+ work_dirs,
+ None,
+ cx,
+ )
+ });
+
+ // Let the load RPC reach the agent and park on the gate.
+ cx.run_until_parked();
+ assert_eq!(
+ load_count.load(Ordering::SeqCst),
+ 1,
+ "load_session RPC should have been dispatched"
+ );
+ assert!(
+ connection
+ .pending_sessions
+ .borrow()
+ .contains_key(&session_id),
+ "pending_sessions entry should exist while load is in flight"
+ );
+ assert!(
+ connection.sessions.borrow().contains_key(&session_id),
+ "sessions entry should be pre-registered to receive replay notifications"
+ );
+
+ // Close the session while the load is still parked. This should take
+ // the pending path and dispatch the ACP close RPC.
+ let close_task = cx.update(|cx| connection.clone().close_session(&session_id, cx));
+
+ // Release the gate so the load RPC can finally respond.
+ gate_tx.send(()).await.expect("gate send failed");
+ drop(gate_tx);
+
+ let load_result = load_task.await;
+ close_task.await.expect("close failed");
+ cx.run_until_parked();
+
+ let err = load_result.expect_err("load should fail after close-during-load");
+ assert!(
+ err.to_string()
+ .contains("session was closed before load completed"),
+ "expected close-during-load error, got: {err}"
+ );
+
+ assert_eq!(
+ close_count.load(Ordering::SeqCst),
+ 1,
+ "ACP close_session should be sent exactly once"
+ );
+ assert!(
+ !connection.sessions.borrow().contains_key(&session_id),
+ "sessions entry should be removed after close-during-load"
+ );
+ assert!(
+ !connection
+ .pending_sessions
+ .borrow()
+ .contains_key(&session_id),
+ "pending_sessions entry should be removed after close-during-load"
+ );
+ }
+
+ // Regression test: when two concurrent `load_session` calls share a pending
+ // task and one of them issues `close_session` before the load RPC
+ // resolves, the remaining load must still succeed and the session must
+ // stay live. If `close_session` incorrectly short-circuits via the
+ // `sessions` path (removing the entry while a load is still in flight),
+ // the pending task will fail and both concurrent loaders will lose
+ // their handle.
+ #[gpui::test]
+ async fn test_close_during_load_preserves_other_concurrent_loader(
+ cx: &mut gpui::TestAppContext,
+ ) {
+ let (
+ connection,
+ project,
+ load_count,
+ close_count,
+ _load_session_updates,
+ load_session_gate,
+ _keep_agent_alive,
+ ) = connect_fake_agent(cx).await;
+
+ let (gate_tx, gate_rx) = smol::channel::bounded::<()>(1);
+ *load_session_gate.borrow_mut() = Some(gate_rx);
+
+ let session_id = acp::SessionId::new("session-concurrent-close");
+ let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
+
+ // Kick off two concurrent loads; the second must join the first's pending
+ // task rather than issuing a second RPC.
+ let first_load = cx.update(|cx| {
+ connection.clone().load_session(
+ session_id.clone(),
+ project.clone(),
+ work_dirs.clone(),
+ None,
+ cx,
+ )
+ });
+ let second_load = cx.update(|cx| {
+ connection.clone().load_session(
+ session_id.clone(),
+ project.clone(),
+ work_dirs.clone(),
+ None,
+ cx,
+ )
+ });
+
+ cx.run_until_parked();
+ assert_eq!(
+ load_count.load(Ordering::SeqCst),
+ 1,
+ "load_session RPC should only be dispatched once for concurrent loads"
+ );
+
+ // Close one of the two handles while the shared load is still parked.
+ // Because a second loader still holds a pending ref, this should be a
+ // no-op on the wire.
+ cx.update(|cx| connection.clone().close_session(&session_id, cx))
+ .await
+ .expect("close during load failed");
+ assert_eq!(
+ close_count.load(Ordering::SeqCst),
+ 0,
+ "close_session RPC must not be dispatched while another load handle remains"
+ );
+
+ // Release the gate so the load RPC can finally respond.
+ gate_tx.send(()).await.expect("gate send failed");
+ drop(gate_tx);
+
+ let first_thread = first_load.await.expect("first load should still succeed");
+ let second_thread = second_load.await.expect("second load should still succeed");
+ cx.run_until_parked();
+
+ assert_eq!(
+ first_thread.entity_id(),
+ second_thread.entity_id(),
+ "concurrent loads should share one AcpThread"
+ );
+ assert!(
+ connection.sessions.borrow().contains_key(&session_id),
+ "session must remain tracked while a load handle is still outstanding"
+ );
+ assert!(
+ !connection
+ .pending_sessions
+ .borrow()
+ .contains_key(&session_id),
+ "pending_sessions entry should be cleared once the load resolves"
+ );
+
+ // Final close drops ref_count to 0 and dispatches the ACP close RPC.
+ cx.update(|cx| connection.clone().close_session(&session_id, cx))
+ .await
+ .expect("final close failed");
+ cx.run_until_parked();
+ assert_eq!(
+ close_count.load(Ordering::SeqCst),
+ 1,
+ "close_session RPC should fire exactly once when the last handle is released"
+ );
+ assert!(
+ !connection.sessions.borrow().contains_key(&session_id),
+ "session should be removed after final close"
+ );
+ }
}
fn mcp_servers_for_project(project: &Entity<Project>, cx: &App) -> Vec<acp::McpServer> {