Fix ACP requests on the foreground thread

Ben Brandt created

Bridge `SentRequest` results through a oneshot-backed future so ACP
requests can be safely awaited from GPUI foreground tasks instead of
using `block_task` outside `ConnectionTo::spawn`.

Change summary

crates/agent/src/tools/edit_file_tool.rs |   6 
crates/agent/src/tools/read_file_tool.rs |   9 
crates/agent_servers/src/acp.rs          | 155 ++++++++++++++-----------
3 files changed, 91 insertions(+), 79 deletions(-)

Detailed changes

crates/agent/src/tools/edit_file_tool.rs 🔗

@@ -6,7 +6,7 @@ use crate::{
     edit_agent::{EditAgent, EditAgentOutputEvent, EditFormat},
 };
 use acp_thread::Diff;
-use agent_client_protocol::schema::{self as acp, ToolCallLocation, ToolCallUpdateFields};
+use agent_client_protocol::schema as acp;
 use anyhow::{Context as _, Result};
 use collections::HashSet;
 use futures::{FutureExt as _, StreamExt as _};
@@ -260,7 +260,7 @@ impl AgentTool for EditFileTool {
                     let abs_path = project.read(cx).absolute_path(&project_path, cx);
                     if let Some(abs_path) = abs_path.clone() {
                         event_stream.update_fields(
-                            ToolCallUpdateFields::new()
+                            acp::ToolCallUpdateFields::new()
                                 .locations(vec![acp::ToolCallLocation::new(abs_path)]),
                         );
                     }
@@ -409,7 +409,7 @@ impl AgentTool for EditFileTool {
                                     range.start.to_point(&buffer.snapshot()).row
                                 }));
                                 if let Some(abs_path) = abs_path.clone() {
-                                    event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)]));
+                                    event_stream.update_fields(acp::ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path).line(line)]));
                                 }
                                 emitted_location = true;
                             }

crates/agent/src/tools/read_file_tool.rs 🔗

@@ -1,5 +1,5 @@
 use action_log::ActionLog;
-use agent_client_protocol::schema::{self as acp, ToolCallUpdateFields};
+use agent_client_protocol::schema as acp;
 use anyhow::{Context as _, Result, anyhow};
 use futures::FutureExt as _;
 use gpui::{App, Entity, SharedString, Task};
@@ -200,7 +200,7 @@ impl AgentTool for ReadFileTool {
             let file_path = input.path.clone();
 
             cx.update(|_cx| {
-                event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![
+                event_stream.update_fields(acp::ToolCallUpdateFields::new().locations(vec![
                     acp::ToolCallLocation::new(&abs_path)
                         .line(input.start_line.map(|line| line.saturating_sub(1))),
                 ]));
@@ -228,7 +228,7 @@ impl AgentTool for ReadFileTool {
                     .context("processing image")
                     .map_err(tool_content_err)?;
 
-                event_stream.update_fields(ToolCallUpdateFields::new().content(vec![
+                event_stream.update_fields(acp::ToolCallUpdateFields::new().content(vec![
                     acp::ToolCallContent::Content(acp::Content::new(acp::ContentBlock::Image(
                         acp::ImageContent::new(language_model_image.source.clone(), "image/png"),
                     ))),
@@ -333,7 +333,7 @@ impl AgentTool for ReadFileTool {
                         text,
                     }
                     .to_string();
-                    event_stream.update_fields(ToolCallUpdateFields::new().content(vec![
+                    event_stream.update_fields(acp::ToolCallUpdateFields::new().content(vec![
                         acp::ToolCallContent::Content(acp::Content::new(markdown)),
                     ]));
                 }
@@ -347,7 +347,6 @@ impl AgentTool for ReadFileTool {
 #[cfg(test)]
 mod test {
     use super::*;
-    use agent_client_protocol::schema as acp;
     use fs::Fs as _;
     use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
     use project::{FakeFs, Project};

crates/agent_servers/src/acp.rs 🔗

@@ -5,13 +5,15 @@ use acp_thread::{
 use acp_tools::{AcpConnectionRegistry, StreamMessage, StreamMessageDirection};
 use action_log::ActionLog;
 use agent_client_protocol::schema::{self as acp, ErrorCode};
-use agent_client_protocol::{Agent, Client, ConnectionTo, JsonRpcResponse, Lines, Responder};
+use agent_client_protocol::{
+    Agent, Client, ConnectionTo, JsonRpcResponse, Lines, Responder, SentRequest,
+};
 use anyhow::anyhow;
 use collections::HashMap;
 use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _};
 use futures::channel::mpsc;
 use futures::io::BufReader;
-use futures::{AsyncBufReadExt as _, StreamExt as _};
+use futures::{AsyncBufReadExt as _, Future, StreamExt as _};
 use project::agent_server_store::AgentServerCommand;
 use project::{AgentId, Project};
 use serde::Deserialize;
@@ -38,6 +40,27 @@ use crate::GEMINI_ID;
 
 pub const GEMINI_TERMINAL_AUTH_METHOD_ID: &str = "spawn-gemini-cli";
 
+/// Converts a [`SentRequest`] into a `Future` that can be safely awaited from
+/// the GPUI foreground thread.
+///
+/// Unlike [`SentRequest::block_task`], which is only safe inside
+/// [`ConnectionTo::spawn`] tasks, this uses [`SentRequest::on_receiving_result`]
+/// to bridge the response through a oneshot channel. The SDK callback is trivial
+/// (just a channel send), so it doesn't meaningfully block the dispatch loop.
+fn into_foreground_future<T: JsonRpcResponse + Send + 'static>(
+    sent: SentRequest<T>,
+) -> impl Future<Output = Result<T, acp::Error>> {
+    let (tx, rx) = futures::channel::oneshot::channel();
+    let spawn_result = sent.on_receiving_result(async move |result| {
+        tx.send(result).ok();
+        Ok(())
+    });
+    async move {
+        spawn_result?;
+        rx.await.map_err(|_| acp::Error::internal_error())?
+    }
+}
+
 #[derive(Debug, Error)]
 #[error("Unsupported version")]
 pub struct UnsupportedVersion;
@@ -135,9 +158,7 @@ impl AgentSessionList for AcpSessionList {
             let acp_request = acp::ListSessionsRequest::new()
                 .cwd(request.cwd)
                 .cursor(request.cursor);
-            let response = conn
-                .send_request(acp_request)
-                .block_task()
+            let response = into_foreground_future(conn.send_request(acp_request))
                 .await
                 .map_err(map_acp_error)?;
             Ok(AgentSessionListResponse {
@@ -451,8 +472,8 @@ impl AcpConnection {
             });
         });
 
-        let response = connection
-            .send_request(
+        let response = into_foreground_future(
+            connection.send_request(
                 acp::InitializeRequest::new(acp::ProtocolVersion::V1)
                     .client_capabilities(
                         acp::ClientCapabilities::new()
@@ -470,9 +491,9 @@ impl AcpConnection {
                         acp::Implementation::new("zed", version)
                             .title(release_channel.map(ToOwned::to_owned)),
                     ),
-            )
-            .block_task()
-            .await?;
+            ),
+        )
+        .await?;
 
         if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
             return Err(UnsupportedVersion.into());
@@ -480,7 +501,9 @@ impl AcpConnection {
 
         let telemetry_id = response
             .agent_info
+            // Use the one the agent provides if we have one
             .map(|info| info.name.into())
+            // Otherwise, just use the name
             .unwrap_or_else(|| agent_id.0.to_string().into());
 
         let session_list = if response
@@ -603,15 +626,15 @@ impl AcpConnection {
                 let config_opts = config_options.clone();
                 let conn = self.connection.clone();
                 async move |_| {
-                    let result = conn
-                        .send_request(acp::SetSessionConfigOptionRequest::new(
+                    let result = into_foreground_future(conn.send_request(
+                        acp::SetSessionConfigOptionRequest::new(
                             session_id,
                             config_id_clone.clone(),
                             default_value_id,
-                        ))
-                        .block_task()
-                        .await
-                        .log_err();
+                        ),
+                    ))
+                    .await
+                    .log_err();
 
                     if result.is_none() {
                         if let Some(initial) = initial_value {
@@ -724,14 +747,12 @@ impl AgentConnection for AcpConnection {
         let mcp_servers = mcp_servers_for_project(&project, cx);
 
         cx.spawn(async move |cx| {
-            let response = self
-                .connection
-                .send_request(
-                    acp::NewSessionRequest::new(cwd.clone()).mcp_servers(mcp_servers),
-                )
-                .block_task()
-                .await
-                .map_err(map_acp_error)?;
+            let response = into_foreground_future(
+                self.connection
+                    .send_request(acp::NewSessionRequest::new(cwd.clone()).mcp_servers(mcp_servers)),
+            )
+            .await
+            .map_err(map_acp_error)?;
 
             let (modes, models, config_options) =
                 config_state(response.modes, response.models, response.config_options);
@@ -753,14 +774,14 @@ impl AgentConnection for AcpConnection {
                             let modes = modes.clone();
                             let conn = self.connection.clone();
                             async move |_| {
-                                let result = conn
-                                    .send_request(acp::SetSessionModeRequest::new(
+                                let result = into_foreground_future(
+                                    conn.send_request(acp::SetSessionModeRequest::new(
                                         session_id,
                                         default_mode,
-                                    ))
-                                    .block_task()
-                                    .await
-                                    .log_err();
+                                    )),
+                                )
+                                .await
+                                .log_err();
 
                                 if result.is_none() {
                                     modes.borrow_mut().current_mode_id = initial_mode_id;
@@ -802,14 +823,14 @@ impl AgentConnection for AcpConnection {
                             let models = models.clone();
                             let conn = self.connection.clone();
                             async move |_| {
-                                let result = conn
-                                    .send_request(acp::SetSessionModelRequest::new(
+                                let result = into_foreground_future(
+                                    conn.send_request(acp::SetSessionModelRequest::new(
                                         session_id,
                                         default_model,
-                                    ))
-                                    .block_task()
-                                    .await
-                                    .log_err();
+                                    )),
+                                )
+                                .await
+                                .log_err();
 
                                 if result.is_none() {
                                     models.borrow_mut().current_model_id = initial_model_id;
@@ -848,6 +869,7 @@ impl AgentConnection for AcpConnection {
                     project,
                     action_log,
                     response.session_id.clone(),
+                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
                     watch::Receiver::constant(
                         self.agent_capabilities.prompt_capabilities.clone(),
                     ),
@@ -927,13 +949,10 @@ impl AgentConnection for AcpConnection {
         );
 
         cx.spawn(async move |cx| {
-            let response = match self
-                .connection
-                .send_request(
-                    acp::LoadSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers),
-                )
-                .block_task()
-                .await
+            let response = match into_foreground_future(self.connection.send_request(
+                acp::LoadSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers),
+            ))
+            .await
             {
                 Ok(response) => response,
                 Err(err) => {
@@ -1010,14 +1029,10 @@ impl AgentConnection for AcpConnection {
         );
 
         cx.spawn(async move |cx| {
-            let response = match self
-                .connection
-                .send_request(
-                    acp::ResumeSessionRequest::new(session_id.clone(), cwd)
-                        .mcp_servers(mcp_servers),
-                )
-                .block_task()
-                .await
+            let response = match into_foreground_future(self.connection.send_request(
+                acp::ResumeSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers),
+            ))
+            .await
             {
                 Ok(response) => response,
                 Err(err) => {
@@ -1061,9 +1076,10 @@ impl AgentConnection for AcpConnection {
         let conn = self.connection.clone();
         let session_id = session_id.clone();
         cx.foreground_executor().spawn(async move {
-            conn.send_request(acp::CloseSessionRequest::new(session_id.clone()))
-                .block_task()
-                .await?;
+            into_foreground_future(
+                conn.send_request(acp::CloseSessionRequest::new(session_id.clone())),
+            )
+            .await?;
             self.sessions.borrow_mut().remove(&session_id);
             Ok(())
         })
@@ -1094,8 +1110,7 @@ impl AgentConnection for AcpConnection {
     fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
         let conn = self.connection.clone();
         cx.foreground_executor().spawn(async move {
-            conn.send_request(acp::AuthenticateRequest::new(method_id))
-                .block_task()
+            into_foreground_future(conn.send_request(acp::AuthenticateRequest::new(method_id)))
                 .await?;
             Ok(())
         })
@@ -1111,7 +1126,7 @@ impl AgentConnection for AcpConnection {
         let sessions = self.sessions.clone();
         let session_id = params.session_id.clone();
         cx.foreground_executor().spawn(async move {
-            let result = conn.send_request(params).block_task().await;
+            let result = into_foreground_future(conn.send_request(params)).await;
 
             let mut suppress_abort_err = false;
 
@@ -1489,10 +1504,10 @@ impl acp_thread::AgentSessionModes for AcpSessionModes {
         };
         let state = self.state.clone();
         cx.foreground_executor().spawn(async move {
-            let result = connection
-                .send_request(acp::SetSessionModeRequest::new(session_id, mode_id))
-                .block_task()
-                .await;
+            let result = into_foreground_future(
+                connection.send_request(acp::SetSessionModeRequest::new(session_id, mode_id)),
+            )
+            .await;
 
             if result.is_err() {
                 state.borrow_mut().current_mode_id = old_mode_id;
@@ -1549,10 +1564,10 @@ impl acp_thread::AgentModelSelector for AcpModelSelector {
         };
         let state = self.state.clone();
         cx.foreground_executor().spawn(async move {
-            let result = connection
-                .send_request(acp::SetSessionModelRequest::new(session_id, model_id))
-                .block_task()
-                .await;
+            let result = into_foreground_future(
+                connection.send_request(acp::SetSessionModelRequest::new(session_id, model_id)),
+            )
+            .await;
 
             if result.is_err() {
                 state.borrow_mut().current_model_id = old_model_id;
@@ -1604,12 +1619,10 @@ impl acp_thread::AgentSessionConfigOptions for AcpSessionConfigOptions {
         let watch_tx = self.watch_tx.clone();
 
         cx.foreground_executor().spawn(async move {
-            let response = connection
-                .send_request(acp::SetSessionConfigOptionRequest::new(
-                    session_id, config_id, value,
-                ))
-                .block_task()
-                .await?;
+            let response = into_foreground_future(connection.send_request(
+                acp::SetSessionConfigOptionRequest::new(session_id, config_id, value),
+            ))
+            .await?;
 
             *state.borrow_mut() = response.config_options.clone();
             watch_tx.borrow_mut().send(()).ok();