diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index a43241a407bef97ce445431fc49275366232aa88..8f9d23f0007391779a7ae0bd41360a70d034f745 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -6,7 +6,7 @@ use crate::{ edit_agent::{EditAgent, EditAgentOutputEvent, EditFormat}, }; use acp_thread::Diff; -use agent_client_protocol::schema::{self as acp, ToolCallLocation, ToolCallUpdateFields}; +use agent_client_protocol::schema as acp; use anyhow::{Context as _, Result}; use collections::HashSet; use futures::{FutureExt as _, StreamExt as _}; @@ -260,7 +260,7 @@ impl AgentTool for EditFileTool { let abs_path = project.read(cx).absolute_path(&project_path, cx); if let Some(abs_path) = abs_path.clone() { event_stream.update_fields( - ToolCallUpdateFields::new() + acp::ToolCallUpdateFields::new() .locations(vec![acp::ToolCallLocation::new(abs_path)]), ); } @@ -409,7 +409,7 @@ impl AgentTool for EditFileTool { range.start.to_point(&buffer.snapshot()).row })); if let Some(abs_path) = abs_path.clone() { - event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)])); + event_stream.update_fields(acp::ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path).line(line)])); } emitted_location = true; } diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index e59e2e4db3a4ed0840d4ed8908df8b922939c145..5b5f9ca1af6eced21d4bda4b63fe99be8cda040a 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -1,5 +1,5 @@ use action_log::ActionLog; -use agent_client_protocol::schema::{self as acp, ToolCallUpdateFields}; +use agent_client_protocol::schema as acp; use anyhow::{Context as _, Result, anyhow}; use futures::FutureExt as _; use gpui::{App, Entity, SharedString, Task}; @@ -200,7 +200,7 @@ impl AgentTool for ReadFileTool { let file_path = input.path.clone(); cx.update(|_cx| { - event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ + event_stream.update_fields(acp::ToolCallUpdateFields::new().locations(vec![ acp::ToolCallLocation::new(&abs_path) .line(input.start_line.map(|line| line.saturating_sub(1))), ])); @@ -228,7 +228,7 @@ impl AgentTool for ReadFileTool { .context("processing image") .map_err(tool_content_err)?; - event_stream.update_fields(ToolCallUpdateFields::new().content(vec![ + event_stream.update_fields(acp::ToolCallUpdateFields::new().content(vec![ acp::ToolCallContent::Content(acp::Content::new(acp::ContentBlock::Image( acp::ImageContent::new(language_model_image.source.clone(), "image/png"), ))), @@ -333,7 +333,7 @@ impl AgentTool for ReadFileTool { text, } .to_string(); - event_stream.update_fields(ToolCallUpdateFields::new().content(vec![ + event_stream.update_fields(acp::ToolCallUpdateFields::new().content(vec![ acp::ToolCallContent::Content(acp::Content::new(markdown)), ])); } @@ -347,7 +347,6 @@ impl AgentTool for ReadFileTool { #[cfg(test)] mod test { use super::*; - use agent_client_protocol::schema as acp; use fs::Fs as _; use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use project::{FakeFs, Project}; diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index b81de679982649cb4999462d62f9b3dc4ee5d766..99c49da208fd739246f75c2251c5a75285f2d4db 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -5,13 +5,15 @@ use acp_thread::{ use acp_tools::{AcpConnectionRegistry, StreamMessage, StreamMessageDirection}; use action_log::ActionLog; use agent_client_protocol::schema::{self as acp, ErrorCode}; -use agent_client_protocol::{Agent, Client, ConnectionTo, JsonRpcResponse, Lines, Responder}; +use agent_client_protocol::{ + Agent, Client, ConnectionTo, JsonRpcResponse, Lines, Responder, SentRequest, +}; use anyhow::anyhow; use collections::HashMap; use feature_flags::{AcpBetaFeatureFlag, FeatureFlagAppExt as _}; use futures::channel::mpsc; use futures::io::BufReader; -use futures::{AsyncBufReadExt as _, StreamExt as _}; +use futures::{AsyncBufReadExt as _, Future, StreamExt as _}; use project::agent_server_store::AgentServerCommand; use project::{AgentId, Project}; use serde::Deserialize; @@ -38,6 +40,27 @@ use crate::GEMINI_ID; pub const GEMINI_TERMINAL_AUTH_METHOD_ID: &str = "spawn-gemini-cli"; +/// Converts a [`SentRequest`] into a `Future` that can be safely awaited from +/// the GPUI foreground thread. +/// +/// Unlike [`SentRequest::block_task`], which is only safe inside +/// [`ConnectionTo::spawn`] tasks, this uses [`SentRequest::on_receiving_result`] +/// to bridge the response through a oneshot channel. The SDK callback is trivial +/// (just a channel send), so it doesn't meaningfully block the dispatch loop. +fn into_foreground_future( + sent: SentRequest, +) -> impl Future> { + let (tx, rx) = futures::channel::oneshot::channel(); + let spawn_result = sent.on_receiving_result(async move |result| { + tx.send(result).ok(); + Ok(()) + }); + async move { + spawn_result?; + rx.await.map_err(|_| acp::Error::internal_error())? + } +} + #[derive(Debug, Error)] #[error("Unsupported version")] pub struct UnsupportedVersion; @@ -135,9 +158,7 @@ impl AgentSessionList for AcpSessionList { let acp_request = acp::ListSessionsRequest::new() .cwd(request.cwd) .cursor(request.cursor); - let response = conn - .send_request(acp_request) - .block_task() + let response = into_foreground_future(conn.send_request(acp_request)) .await .map_err(map_acp_error)?; Ok(AgentSessionListResponse { @@ -451,8 +472,8 @@ impl AcpConnection { }); }); - let response = connection - .send_request( + let response = into_foreground_future( + connection.send_request( acp::InitializeRequest::new(acp::ProtocolVersion::V1) .client_capabilities( acp::ClientCapabilities::new() @@ -470,9 +491,9 @@ impl AcpConnection { acp::Implementation::new("zed", version) .title(release_channel.map(ToOwned::to_owned)), ), - ) - .block_task() - .await?; + ), + ) + .await?; if response.protocol_version < MINIMUM_SUPPORTED_VERSION { return Err(UnsupportedVersion.into()); @@ -480,7 +501,9 @@ impl AcpConnection { let telemetry_id = response .agent_info + // Use the one the agent provides if we have one .map(|info| info.name.into()) + // Otherwise, just use the name .unwrap_or_else(|| agent_id.0.to_string().into()); let session_list = if response @@ -603,15 +626,15 @@ impl AcpConnection { let config_opts = config_options.clone(); let conn = self.connection.clone(); async move |_| { - let result = conn - .send_request(acp::SetSessionConfigOptionRequest::new( + let result = into_foreground_future(conn.send_request( + acp::SetSessionConfigOptionRequest::new( session_id, config_id_clone.clone(), default_value_id, - )) - .block_task() - .await - .log_err(); + ), + )) + .await + .log_err(); if result.is_none() { if let Some(initial) = initial_value { @@ -724,14 +747,12 @@ impl AgentConnection for AcpConnection { let mcp_servers = mcp_servers_for_project(&project, cx); cx.spawn(async move |cx| { - let response = self - .connection - .send_request( - acp::NewSessionRequest::new(cwd.clone()).mcp_servers(mcp_servers), - ) - .block_task() - .await - .map_err(map_acp_error)?; + let response = into_foreground_future( + self.connection + .send_request(acp::NewSessionRequest::new(cwd.clone()).mcp_servers(mcp_servers)), + ) + .await + .map_err(map_acp_error)?; let (modes, models, config_options) = config_state(response.modes, response.models, response.config_options); @@ -753,14 +774,14 @@ impl AgentConnection for AcpConnection { let modes = modes.clone(); let conn = self.connection.clone(); async move |_| { - let result = conn - .send_request(acp::SetSessionModeRequest::new( + let result = into_foreground_future( + conn.send_request(acp::SetSessionModeRequest::new( session_id, default_mode, - )) - .block_task() - .await - .log_err(); + )), + ) + .await + .log_err(); if result.is_none() { modes.borrow_mut().current_mode_id = initial_mode_id; @@ -802,14 +823,14 @@ impl AgentConnection for AcpConnection { let models = models.clone(); let conn = self.connection.clone(); async move |_| { - let result = conn - .send_request(acp::SetSessionModelRequest::new( + let result = into_foreground_future( + conn.send_request(acp::SetSessionModelRequest::new( session_id, default_model, - )) - .block_task() - .await - .log_err(); + )), + ) + .await + .log_err(); if result.is_none() { models.borrow_mut().current_model_id = initial_model_id; @@ -848,6 +869,7 @@ impl AgentConnection for AcpConnection { project, action_log, response.session_id.clone(), + // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically. watch::Receiver::constant( self.agent_capabilities.prompt_capabilities.clone(), ), @@ -927,13 +949,10 @@ impl AgentConnection for AcpConnection { ); cx.spawn(async move |cx| { - let response = match self - .connection - .send_request( - acp::LoadSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers), - ) - .block_task() - .await + let response = match into_foreground_future(self.connection.send_request( + acp::LoadSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers), + )) + .await { Ok(response) => response, Err(err) => { @@ -1010,14 +1029,10 @@ impl AgentConnection for AcpConnection { ); cx.spawn(async move |cx| { - let response = match self - .connection - .send_request( - acp::ResumeSessionRequest::new(session_id.clone(), cwd) - .mcp_servers(mcp_servers), - ) - .block_task() - .await + let response = match into_foreground_future(self.connection.send_request( + acp::ResumeSessionRequest::new(session_id.clone(), cwd).mcp_servers(mcp_servers), + )) + .await { Ok(response) => response, Err(err) => { @@ -1061,9 +1076,10 @@ impl AgentConnection for AcpConnection { let conn = self.connection.clone(); let session_id = session_id.clone(); cx.foreground_executor().spawn(async move { - conn.send_request(acp::CloseSessionRequest::new(session_id.clone())) - .block_task() - .await?; + into_foreground_future( + conn.send_request(acp::CloseSessionRequest::new(session_id.clone())), + ) + .await?; self.sessions.borrow_mut().remove(&session_id); Ok(()) }) @@ -1094,8 +1110,7 @@ impl AgentConnection for AcpConnection { fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { let conn = self.connection.clone(); cx.foreground_executor().spawn(async move { - conn.send_request(acp::AuthenticateRequest::new(method_id)) - .block_task() + into_foreground_future(conn.send_request(acp::AuthenticateRequest::new(method_id))) .await?; Ok(()) }) @@ -1111,7 +1126,7 @@ impl AgentConnection for AcpConnection { let sessions = self.sessions.clone(); let session_id = params.session_id.clone(); cx.foreground_executor().spawn(async move { - let result = conn.send_request(params).block_task().await; + let result = into_foreground_future(conn.send_request(params)).await; let mut suppress_abort_err = false; @@ -1489,10 +1504,10 @@ impl acp_thread::AgentSessionModes for AcpSessionModes { }; let state = self.state.clone(); cx.foreground_executor().spawn(async move { - let result = connection - .send_request(acp::SetSessionModeRequest::new(session_id, mode_id)) - .block_task() - .await; + let result = into_foreground_future( + connection.send_request(acp::SetSessionModeRequest::new(session_id, mode_id)), + ) + .await; if result.is_err() { state.borrow_mut().current_mode_id = old_mode_id; @@ -1549,10 +1564,10 @@ impl acp_thread::AgentModelSelector for AcpModelSelector { }; let state = self.state.clone(); cx.foreground_executor().spawn(async move { - let result = connection - .send_request(acp::SetSessionModelRequest::new(session_id, model_id)) - .block_task() - .await; + let result = into_foreground_future( + connection.send_request(acp::SetSessionModelRequest::new(session_id, model_id)), + ) + .await; if result.is_err() { state.borrow_mut().current_model_id = old_model_id; @@ -1604,12 +1619,10 @@ impl acp_thread::AgentSessionConfigOptions for AcpSessionConfigOptions { let watch_tx = self.watch_tx.clone(); cx.foreground_executor().spawn(async move { - let response = connection - .send_request(acp::SetSessionConfigOptionRequest::new( - session_id, config_id, value, - )) - .block_task() - .await?; + let response = into_foreground_future(connection.send_request( + acp::SetSessionConfigOptionRequest::new(session_id, config_id, value), + )) + .await?; *state.borrow_mut() = response.config_options.clone(); watch_tx.borrow_mut().send(()).ok();