Cargo.lock 🔗
@@ -274,6 +274,7 @@ dependencies = [
"libc",
"log",
"nix 0.29.0",
+ "piper",
"project",
"release_channel",
"remote",
Bennet Bo Fenner created
Follow up to #53999
Self-Review Checklist:
- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable
Closes #ISSUE
Release Notes:
- N/A
Cargo.lock | 1
crates/agent/src/agent.rs | 2
crates/agent_servers/Cargo.toml | 3
crates/agent_servers/src/acp.rs | 579 +++++++++++++++++++++++++++-------
4 files changed, 460 insertions(+), 125 deletions(-)
@@ -274,6 +274,7 @@ dependencies = [
"libc",
"log",
"nix 0.29.0",
+ "piper",
"project",
"release_channel",
"remote",
@@ -944,7 +944,7 @@ impl NativeAgent {
if let Some(pending) = self.pending_sessions.get_mut(&id) {
pending.ref_count += 1;
let task = pending.task.clone();
- return cx.spawn(async move |_, _cx| task.await.map_err(|err| anyhow!(err)));
+ return cx.background_spawn(async move { task.await.map_err(|err| anyhow!(err)) });
}
let task = self.load_thread(id.clone(), project.clone(), cx);
@@ -68,4 +68,7 @@ indoc.workspace = true
acp_thread = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
gpui_tokio.workspace = true
+piper = "0.2"
+project = { workspace = true, features = ["test-support"] }
reqwest_client = { workspace = true, features = ["test-support"] }
+settings = { workspace = true, features = ["test-support"] }
@@ -9,6 +9,8 @@ use anyhow::anyhow;
use collections::HashMap;
use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncBufReadExt as _;
+use futures::FutureExt as _;
+use futures::future::Shared;
use futures::io::BufReader;
use project::agent_server_store::{AgentServerCommand, AgentServerStore};
use project::{AgentId, Project};
@@ -25,6 +27,8 @@ use util::ResultExt as _;
use util::path_list::PathList;
use util::process::Child;
+use std::sync::Arc;
+
use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
@@ -45,19 +49,31 @@ pub struct AcpConnection {
telemetry_id: SharedString,
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
+ pending_sessions: Rc<RefCell<HashMap<acp::SessionId, PendingAcpSession>>>,
auth_methods: Vec<acp::AuthMethod>,
agent_server_store: WeakEntity<AgentServerStore>,
agent_capabilities: acp::AgentCapabilities,
default_mode: Option<acp::SessionModeId>,
default_model: Option<acp::ModelId>,
default_config_options: HashMap<String, String>,
- child: Child,
+ child: Option<Child>,
session_list: Option<Rc<AcpSessionList>>,
_io_task: Task<Result<(), acp::Error>>,
_wait_task: Task<Result<()>>,
_stderr_task: Task<Result<()>>,
}
+struct PendingAcpSession {
+ task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
+ ref_count: usize,
+}
+
+struct SessionConfigResponse {
+ modes: Option<acp::SessionModeState>,
+ models: Option<acp::SessionModelState>,
+ config_options: Option<Vec<acp::SessionConfigOption>>,
+}
+
struct ConfigOptions {
config_options: Rc<RefCell<Vec<acp::SessionConfigOption>>>,
tx: Rc<RefCell<watch::Sender<()>>>,
@@ -81,6 +97,7 @@ pub struct AcpSession {
models: Option<Rc<RefCell<acp::SessionModelState>>>,
session_modes: Option<Rc<RefCell<acp::SessionModeState>>>,
config_options: Option<ConfigOptions>,
+ ref_count: usize,
}
pub struct AcpSessionList {
@@ -393,6 +410,7 @@ impl AcpConnection {
connection,
telemetry_id,
sessions,
+ pending_sessions: Rc::new(RefCell::new(HashMap::default())),
agent_capabilities: response.agent_capabilities,
default_mode,
default_model,
@@ -401,7 +419,7 @@ impl AcpConnection {
_io_task: io_task,
_wait_task: wait_task,
_stderr_task: stderr_task,
- child,
+ child: Some(child),
})
}
@@ -409,6 +427,143 @@ impl AcpConnection {
&self.agent_capabilities.prompt_capabilities
}
+ #[cfg(test)]
+ fn new_for_test(
+ connection: Rc<acp::ClientSideConnection>,
+ sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
+ agent_capabilities: acp::AgentCapabilities,
+ agent_server_store: WeakEntity<AgentServerStore>,
+ io_task: Task<Result<(), acp::Error>>,
+ _cx: &mut App,
+ ) -> Self {
+ Self {
+ id: AgentId::new("test"),
+ telemetry_id: "test".into(),
+ connection,
+ sessions,
+ pending_sessions: Rc::new(RefCell::new(HashMap::default())),
+ auth_methods: vec![],
+ agent_server_store,
+ agent_capabilities,
+ default_mode: None,
+ default_model: None,
+ default_config_options: HashMap::default(),
+ child: None,
+ session_list: None,
+ _io_task: io_task,
+ _wait_task: Task::ready(Ok(())),
+ _stderr_task: Task::ready(Ok(())),
+ }
+ }
+
+ fn open_or_create_session(
+ self: Rc<Self>,
+ session_id: acp::SessionId,
+ project: Entity<Project>,
+ work_dirs: PathList,
+ title: Option<SharedString>,
+ rpc_call: impl FnOnce(
+ Rc<acp::ClientSideConnection>,
+ acp::SessionId,
+ PathBuf,
+ )
+ -> futures::future::LocalBoxFuture<'static, Result<SessionConfigResponse>>
+ + 'static,
+ cx: &mut App,
+ ) -> Task<Result<Entity<AcpThread>>> {
+ if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
+ session.ref_count += 1;
+ if let Some(thread) = session.thread.upgrade() {
+ return Task::ready(Ok(thread));
+ }
+ }
+
+ if let Some(pending) = self.pending_sessions.borrow_mut().get_mut(&session_id) {
+ pending.ref_count += 1;
+ let task = pending.task.clone();
+ return cx
+ .foreground_executor()
+ .spawn(async move { task.await.map_err(|err| anyhow!(err)) });
+ }
+
+ // TODO: remove this once ACP supports multiple working directories
+ let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
+ return Task::ready(Err(anyhow!("Working directory cannot be empty")));
+ };
+
+ let shared_task = cx
+ .spawn({
+ let session_id = session_id.clone();
+ let this = self.clone();
+ async move |cx| {
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let thread: Entity<AcpThread> = cx.new(|cx| {
+ AcpThread::new(
+ None,
+ title,
+ Some(work_dirs),
+ this.clone(),
+ project,
+ action_log,
+ session_id.clone(),
+ watch::Receiver::constant(
+ this.agent_capabilities.prompt_capabilities.clone(),
+ ),
+ cx,
+ )
+ });
+
+ let response =
+ match rpc_call(this.connection.clone(), session_id.clone(), cwd).await {
+ Ok(response) => response,
+ Err(err) => {
+ this.pending_sessions.borrow_mut().remove(&session_id);
+ return Err(Arc::new(err));
+ }
+ };
+
+ let (modes, models, config_options) =
+ config_state(response.modes, response.models, response.config_options);
+
+ if let Some(config_opts) = config_options.as_ref() {
+ this.apply_default_config_options(&session_id, config_opts, cx);
+ }
+
+ let ref_count = this
+ .pending_sessions
+ .borrow_mut()
+ .remove(&session_id)
+ .map_or(1, |pending| pending.ref_count);
+
+ this.sessions.borrow_mut().insert(
+ session_id,
+ AcpSession {
+ thread: thread.downgrade(),
+ suppress_abort_err: false,
+ session_modes: modes,
+ models,
+ config_options: config_options.map(ConfigOptions::new),
+ ref_count,
+ },
+ );
+
+ Ok(thread)
+ }
+ })
+ .shared();
+
+ self.pending_sessions.borrow_mut().insert(
+ session_id,
+ PendingAcpSession {
+ task: shared_task.clone(),
+ ref_count: 1,
+ },
+ );
+
+ cx.foreground_executor()
+ .spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
+ }
+
fn apply_default_config_options(
&self,
session_id: &acp::SessionId,
@@ -508,7 +663,9 @@ impl AcpConnection {
impl Drop for AcpConnection {
fn drop(&mut self) {
- self.child.kill().log_err();
+ if let Some(ref mut child) = self.child {
+ child.kill().log_err();
+ }
}
}
@@ -700,6 +857,7 @@ impl AgentConnection for AcpConnection {
session_modes: modes,
models,
config_options: config_options.map(ConfigOptions::new),
+ ref_count: 1,
},
);
@@ -731,68 +889,30 @@ impl AgentConnection for AcpConnection {
"Loading sessions is not supported by this agent.".into()
))));
}
- // TODO: remove this once ACP supports multiple working directories
- let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
- return Task::ready(Err(anyhow!("Working directory cannot be empty")));
- };
let mcp_servers = mcp_servers_for_project(&project, cx);
- let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let thread: Entity<AcpThread> = cx.new(|cx| {
- AcpThread::new(
- None,
- title,
- Some(work_dirs.clone()),
- self.clone(),
- project,
- action_log,
- session_id.clone(),
- watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
- cx,
- )
- });
-
- self.sessions.borrow_mut().insert(
- session_id.clone(),
- AcpSession {
- thread: thread.downgrade(),
- suppress_abort_err: false,
- session_modes: None,
- models: None,
- config_options: None,
+ self.open_or_create_session(
+ session_id,
+ project,
+ work_dirs,
+ title,
+ move |connection, session_id, cwd| {
+ Box::pin(async move {
+ let response = connection
+ .load_session(
+ acp::LoadSessionRequest::new(session_id, cwd).mcp_servers(mcp_servers),
+ )
+ .await
+ .map_err(map_acp_error)?;
+ Ok(SessionConfigResponse {
+ modes: response.modes,
+ models: response.models,
+ config_options: response.config_options,
+ })
+ })
},
- );
-
- cx.spawn(async move |cx| {
- let response = match self
- .connection
- .load_session(
- acp::LoadSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers),
- )
- .await
- {
- Ok(response) => response,
- Err(err) => {
- self.sessions.borrow_mut().remove(&session_id);
- return Err(map_acp_error(err));
- }
- };
-
- let (modes, models, config_options) =
- config_state(response.modes, response.models, response.config_options);
-
- if let Some(config_opts) = config_options.as_ref() {
- self.apply_default_config_options(&session_id, config_opts, cx);
- }
-
- if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
- session.session_modes = modes;
- session.models = models;
- session.config_options = config_options.map(ConfigOptions::new);
- }
-
- Ok(thread)
- })
+ cx,
+ )
}
fn resume_session(
@@ -813,69 +933,31 @@ impl AgentConnection for AcpConnection {
"Resuming sessions is not supported by this agent.".into()
))));
}
- // TODO: remove this once ACP supports multiple working directories
- let Some(cwd) = work_dirs.ordered_paths().next().cloned() else {
- return Task::ready(Err(anyhow!("Working directory cannot be empty")));
- };
let mcp_servers = mcp_servers_for_project(&project, cx);
- let action_log = cx.new(|_| ActionLog::new(project.clone()));
- let thread: Entity<AcpThread> = cx.new(|cx| {
- AcpThread::new(
- None,
- title,
- Some(work_dirs),
- self.clone(),
- project,
- action_log,
- session_id.clone(),
- watch::Receiver::constant(self.agent_capabilities.prompt_capabilities.clone()),
- cx,
- )
- });
-
- self.sessions.borrow_mut().insert(
- session_id.clone(),
- AcpSession {
- thread: thread.downgrade(),
- suppress_abort_err: false,
- session_modes: None,
- models: None,
- config_options: None,
+ self.open_or_create_session(
+ session_id,
+ project,
+ work_dirs,
+ title,
+ move |connection, session_id, cwd| {
+ Box::pin(async move {
+ let response = connection
+ .resume_session(
+ acp::ResumeSessionRequest::new(session_id, cwd)
+ .mcp_servers(mcp_servers),
+ )
+ .await
+ .map_err(map_acp_error)?;
+ Ok(SessionConfigResponse {
+ modes: response.modes,
+ models: response.models,
+ config_options: response.config_options,
+ })
+ })
},
- );
-
- cx.spawn(async move |cx| {
- let response = match self
- .connection
- .resume_session(
- acp::ResumeSessionRequest::new(session_id.clone(), cwd)
- .mcp_servers(mcp_servers),
- )
- .await
- {
- Ok(response) => response,
- Err(err) => {
- self.sessions.borrow_mut().remove(&session_id);
- return Err(map_acp_error(err));
- }
- };
-
- let (modes, models, config_options) =
- config_state(response.modes, response.models, response.config_options);
-
- if let Some(config_opts) = config_options.as_ref() {
- self.apply_default_config_options(&session_id, config_opts, cx);
- }
-
- if let Some(session) = self.sessions.borrow_mut().get_mut(&session_id) {
- session.session_modes = modes;
- session.models = models;
- session.config_options = config_options.map(ConfigOptions::new);
- }
-
- Ok(thread)
- })
+ cx,
+ )
}
fn supports_close_session(&self) -> bool {
@@ -893,12 +975,24 @@ impl AgentConnection for AcpConnection {
))));
}
+ let mut sessions = self.sessions.borrow_mut();
+ let Some(session) = sessions.get_mut(session_id) else {
+ return Task::ready(Ok(()));
+ };
+
+ session.ref_count -= 1;
+ if session.ref_count > 0 {
+ return Task::ready(Ok(()));
+ }
+
+ sessions.remove(session_id);
+ drop(sessions);
+
let conn = self.connection.clone();
let session_id = session_id.clone();
cx.foreground_executor().spawn(async move {
- conn.close_session(acp::CloseSessionRequest::new(session_id.clone()))
+ conn.close_session(acp::CloseSessionRequest::new(session_id))
.await?;
- self.sessions.borrow_mut().remove(&session_id);
Ok(())
})
}
@@ -1112,6 +1206,8 @@ fn map_acp_error(err: acp::Error) -> anyhow::Error {
#[cfg(test)]
mod tests {
+ use std::sync::atomic::{AtomicUsize, Ordering};
+
use super::*;
#[test]
@@ -1240,6 +1336,241 @@ mod tests {
);
assert_eq!(task.label, "Login");
}
+
+ struct FakeAcpAgent {
+ load_session_count: Arc<AtomicUsize>,
+ close_session_count: Arc<AtomicUsize>,
+ }
+
+ #[async_trait::async_trait(?Send)]
+ impl acp::Agent for FakeAcpAgent {
+ async fn initialize(
+ &self,
+ args: acp::InitializeRequest,
+ ) -> acp::Result<acp::InitializeResponse> {
+ Ok(
+ acp::InitializeResponse::new(args.protocol_version).agent_capabilities(
+ acp::AgentCapabilities::default()
+ .load_session(true)
+ .session_capabilities(
+ acp::SessionCapabilities::default()
+ .close(acp::SessionCloseCapabilities::new()),
+ ),
+ ),
+ )
+ }
+
+ async fn authenticate(
+ &self,
+ _: acp::AuthenticateRequest,
+ ) -> acp::Result<acp::AuthenticateResponse> {
+ Ok(Default::default())
+ }
+
+ async fn new_session(
+ &self,
+ _: acp::NewSessionRequest,
+ ) -> acp::Result<acp::NewSessionResponse> {
+ Ok(acp::NewSessionResponse::new(acp::SessionId::new("unused")))
+ }
+
+ async fn prompt(&self, _: acp::PromptRequest) -> acp::Result<acp::PromptResponse> {
+ Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
+ }
+
+ async fn cancel(&self, _: acp::CancelNotification) -> acp::Result<()> {
+ Ok(())
+ }
+
+ async fn load_session(
+ &self,
+ _: acp::LoadSessionRequest,
+ ) -> acp::Result<acp::LoadSessionResponse> {
+ self.load_session_count.fetch_add(1, Ordering::SeqCst);
+ Ok(acp::LoadSessionResponse::new())
+ }
+
+ async fn close_session(
+ &self,
+ _: acp::CloseSessionRequest,
+ ) -> acp::Result<acp::CloseSessionResponse> {
+ self.close_session_count.fetch_add(1, Ordering::SeqCst);
+ Ok(acp::CloseSessionResponse::new())
+ }
+ }
+
+ async fn connect_fake_agent(
+ cx: &mut gpui::TestAppContext,
+ ) -> (
+ Rc<AcpConnection>,
+ Entity<project::Project>,
+ Arc<AtomicUsize>,
+ Arc<AtomicUsize>,
+ Task<anyhow::Result<()>>,
+ ) {
+ cx.update(|cx| {
+ let store = settings::SettingsStore::test(cx);
+ cx.set_global(store);
+ });
+
+ let fs = fs::FakeFs::new(cx.executor());
+ fs.insert_tree("/", serde_json::json!({ "a": {} })).await;
+ let project = project::Project::test(fs, [std::path::Path::new("/a")], cx).await;
+
+ let load_count = Arc::new(AtomicUsize::new(0));
+ let close_count = Arc::new(AtomicUsize::new(0));
+
+ let (c2a_reader, c2a_writer) = piper::pipe(4096);
+ let (a2c_reader, a2c_writer) = piper::pipe(4096);
+
+ let sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>> =
+ Rc::new(RefCell::new(HashMap::default()));
+ let session_list_container: Rc<RefCell<Option<Rc<AcpSessionList>>>> =
+ Rc::new(RefCell::new(None));
+
+ let foreground = cx.foreground_executor().clone();
+
+ let client_delegate = ClientDelegate {
+ sessions: sessions.clone(),
+ session_list: session_list_container,
+ cx: cx.to_async(),
+ };
+
+ let (client_conn, client_io_task) =
+ acp::ClientSideConnection::new(client_delegate, c2a_writer, a2c_reader, {
+ let foreground = foreground.clone();
+ move |fut| {
+ foreground.spawn(fut).detach();
+ }
+ });
+
+ let fake_agent = FakeAcpAgent {
+ load_session_count: load_count.clone(),
+ close_session_count: close_count.clone(),
+ };
+
+ let (_, agent_io_task) =
+ acp::AgentSideConnection::new(fake_agent, a2c_writer, c2a_reader, {
+ let foreground = foreground.clone();
+ move |fut| {
+ foreground.spawn(fut).detach();
+ }
+ });
+
+ let client_io_task = cx.background_spawn(client_io_task);
+ let agent_io_task = cx.background_spawn(agent_io_task);
+
+ let response = client_conn
+ .initialize(acp::InitializeRequest::new(acp::ProtocolVersion::V1))
+ .await
+ .expect("failed to initialize ACP connection");
+
+ let agent_capabilities = response.agent_capabilities;
+
+ let agent_server_store =
+ project.read_with(cx, |project, _| project.agent_server_store().downgrade());
+
+ let connection = cx.update(|cx| {
+ AcpConnection::new_for_test(
+ Rc::new(client_conn),
+ sessions,
+ agent_capabilities,
+ agent_server_store,
+ client_io_task,
+ cx,
+ )
+ });
+
+ let keep_agent_alive = cx.background_spawn(async move {
+ agent_io_task.await.ok();
+ anyhow::Ok(())
+ });
+
+ (
+ Rc::new(connection),
+ project,
+ load_count,
+ close_count,
+ keep_agent_alive,
+ )
+ }
+
+ #[gpui::test]
+ async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut gpui::TestAppContext) {
+ let (connection, project, load_count, close_count, _keep_agent_alive) =
+ connect_fake_agent(cx).await;
+
+ let session_id = acp::SessionId::new("session-1");
+ let work_dirs = util::path_list::PathList::new(&[std::path::Path::new("/a")]);
+
+ // Load the same session twice concurrently — the second call should join
+ // the pending task rather than issuing a second ACP load_session RPC.
+ let first_load = cx.update(|cx| {
+ connection.clone().load_session(
+ session_id.clone(),
+ project.clone(),
+ work_dirs.clone(),
+ None,
+ cx,
+ )
+ });
+ let second_load = cx.update(|cx| {
+ connection.clone().load_session(
+ session_id.clone(),
+ project.clone(),
+ work_dirs.clone(),
+ None,
+ cx,
+ )
+ });
+
+ let first_thread = first_load.await.expect("first load failed");
+ let second_thread = second_load.await.expect("second load failed");
+ cx.run_until_parked();
+
+ assert_eq!(
+ first_thread.entity_id(),
+ second_thread.entity_id(),
+ "concurrent loads for the same session should share one AcpThread"
+ );
+ assert_eq!(
+ load_count.load(Ordering::SeqCst),
+ 1,
+ "underlying ACP load_session should be called exactly once for concurrent loads"
+ );
+
+ // The session has ref_count 2. The first close should not send the ACP
+ // close_session RPC — the session is still referenced.
+ cx.update(|cx| connection.clone().close_session(&session_id, cx))
+ .await
+ .expect("first close failed");
+
+ assert_eq!(
+ close_count.load(Ordering::SeqCst),
+ 0,
+ "ACP close_session should not be sent while ref_count > 0"
+ );
+ assert!(
+ connection.sessions.borrow().contains_key(&session_id),
+ "session should still be tracked after first close"
+ );
+
+ // The second close drops ref_count to 0 — now the ACP RPC must be sent.
+ cx.update(|cx| connection.clone().close_session(&session_id, cx))
+ .await
+ .expect("second close failed");
+ cx.run_until_parked();
+
+ assert_eq!(
+ close_count.load(Ordering::SeqCst),
+ 1,
+ "ACP close_session should be sent exactly once when ref_count reaches 0"
+ );
+ assert!(
+ !connection.sessions.borrow().contains_key(&session_id),
+ "session should be removed after final close"
+ );
+ }
}
fn mcp_servers_for_project(project: &Entity<Project>, cx: &App) -> Vec<acp::McpServer> {