1use crate::AgentServerCommand;
2use acp_thread::AgentConnection;
3use acp_tools::AcpConnectionRegistry;
4use action_log::ActionLog;
5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
6use agent_settings::AgentSettings;
7use anyhow::anyhow;
8use collections::HashMap;
9use futures::AsyncBufReadExt as _;
10use futures::channel::oneshot;
11use futures::io::BufReader;
12use project::Project;
13use serde::Deserialize;
14use settings::Settings as _;
15use std::{any::Any, cell::RefCell};
16use std::{path::Path, rc::Rc};
17use thiserror::Error;
18
19use anyhow::{Context as _, Result};
20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
21
22use acp_thread::{AcpThread, AuthRequired, LoadError};
23
/// Error returned by `AcpConnection::stdio` when the agent's `initialize`
/// response reports a protocol version older than `MINIMUM_SUPPORTED_VERSION`.
#[derive(Debug, Error)]
#[error("Unsupported version")]
pub struct UnsupportedVersion;
27
/// A connection to an external agent server process, as set up by
/// `AcpConnection::stdio`.
pub struct AcpConnection {
    /// Name of the agent server; handed to every `AcpThread` created by this connection.
    server_name: SharedString,
    /// Client side of the protocol connection, used to send requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions created through this connection, keyed by session id. The same
    /// map is shared with the `ClientDelegate` that serves requests coming from
    /// the agent.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods reported by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Prompt capabilities reported by the agent in its `initialize` response.
    prompt_capabilities: acp::PromptCapabilities,
    // The tasks below are never read; they are stored so the connection owns them.
    // NOTE(review): presumably dropping a `Task` cancels it, which would tie the
    // lifetime of this background work to the connection — confirm.
    /// Task driving the protocol I/O future returned by `ClientSideConnection::new`.
    _io_task: Task<Result<()>>,
    /// Task that waits for the child process to exit and reports the exit to
    /// every session's thread.
    _wait_task: Task<Result<()>>,
    /// Task that forwards the child process's stderr, line by line, to the log.
    _stderr_task: Task<Result<()>>,
}
38
/// Per-session state tracked by an `AcpConnection`.
pub struct AcpSession {
    /// Weak handle to the thread that was created for this session.
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel` and reset to `false` when the next `prompt` call
    /// completes. While set, an internal error from the agent whose details
    /// mention an aborted operation is reported as a cancelled prompt instead
    /// of an error.
    suppress_abort_err: bool,
}
43
44pub async fn connect(
45 server_name: SharedString,
46 command: AgentServerCommand,
47 root_dir: &Path,
48 cx: &mut AsyncApp,
49) -> Result<Rc<dyn AgentConnection>> {
50 let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
51 Ok(Rc::new(conn) as _)
52}
53
/// Oldest protocol version accepted from an agent. `AcpConnection::stdio` fails
/// with `UnsupportedVersion` when the agent's `initialize` response reports an
/// older version.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
55
56impl AcpConnection {
57 pub async fn stdio(
58 server_name: SharedString,
59 command: AgentServerCommand,
60 root_dir: &Path,
61 cx: &mut AsyncApp,
62 ) -> Result<Self> {
63 let mut child = util::command::new_smol_command(command.path)
64 .args(command.args.iter().map(|arg| arg.as_str()))
65 .envs(command.env.iter().flatten())
66 .current_dir(root_dir)
67 .stdin(std::process::Stdio::piped())
68 .stdout(std::process::Stdio::piped())
69 .stderr(std::process::Stdio::piped())
70 .kill_on_drop(true)
71 .spawn()?;
72
73 let stdout = child.stdout.take().context("Failed to take stdout")?;
74 let stdin = child.stdin.take().context("Failed to take stdin")?;
75 let stderr = child.stderr.take().context("Failed to take stderr")?;
76 log::trace!("Spawned (pid: {})", child.id());
77
78 let sessions = Rc::new(RefCell::new(HashMap::default()));
79
80 let client = ClientDelegate {
81 sessions: sessions.clone(),
82 cx: cx.clone(),
83 };
84 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
85 let foreground_executor = cx.foreground_executor().clone();
86 move |fut| {
87 foreground_executor.spawn(fut).detach();
88 }
89 });
90
91 let io_task = cx.background_spawn(io_task);
92
93 let stderr_task = cx.background_spawn(async move {
94 let mut stderr = BufReader::new(stderr);
95 let mut line = String::new();
96 while let Ok(n) = stderr.read_line(&mut line).await
97 && n > 0
98 {
99 log::warn!("agent stderr: {}", &line);
100 line.clear();
101 }
102 Ok(())
103 });
104
105 let wait_task = cx.spawn({
106 let sessions = sessions.clone();
107 async move |cx| {
108 let status = child.status().await?;
109
110 for session in sessions.borrow().values() {
111 session
112 .thread
113 .update(cx, |thread, cx| {
114 thread.emit_load_error(LoadError::Exited { status }, cx)
115 })
116 .ok();
117 }
118
119 anyhow::Ok(())
120 }
121 });
122
123 let connection = Rc::new(connection);
124
125 cx.update(|cx| {
126 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
127 registry.set_active_connection(server_name.clone(), &connection, cx)
128 });
129 })?;
130
131 let response = connection
132 .initialize(acp::InitializeRequest {
133 protocol_version: acp::VERSION,
134 client_capabilities: acp::ClientCapabilities {
135 fs: acp::FileSystemCapability {
136 read_text_file: true,
137 write_text_file: true,
138 },
139 terminal: true,
140 },
141 })
142 .await?;
143
144 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
145 return Err(UnsupportedVersion.into());
146 }
147
148 Ok(Self {
149 auth_methods: response.auth_methods,
150 connection,
151 server_name,
152 sessions,
153 prompt_capabilities: response.agent_capabilities.prompt_capabilities,
154 _io_task: io_task,
155 _wait_task: wait_task,
156 _stderr_task: stderr_task,
157 })
158 }
159
160 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
161 &self.prompt_capabilities
162 }
163}
164
165impl AgentConnection for AcpConnection {
166 fn new_thread(
167 self: Rc<Self>,
168 project: Entity<Project>,
169 cwd: &Path,
170 cx: &mut App,
171 ) -> Task<Result<Entity<AcpThread>>> {
172 let conn = self.connection.clone();
173 let sessions = self.sessions.clone();
174 let cwd = cwd.to_path_buf();
175 let context_server_store = project.read(cx).context_server_store().read(cx);
176 let mcp_servers = context_server_store
177 .configured_server_ids()
178 .iter()
179 .filter_map(|id| {
180 let configuration = context_server_store.configuration_for_server(id)?;
181 let command = configuration.command();
182 Some(acp::McpServer {
183 name: id.0.to_string(),
184 command: command.path.clone(),
185 args: command.args.clone(),
186 env: if let Some(env) = command.env.as_ref() {
187 env.iter()
188 .map(|(name, value)| acp::EnvVariable {
189 name: name.clone(),
190 value: value.clone(),
191 })
192 .collect()
193 } else {
194 vec![]
195 },
196 })
197 })
198 .collect();
199
200 cx.spawn(async move |cx| {
201 let response = conn
202 .new_session(acp::NewSessionRequest { mcp_servers, cwd })
203 .await
204 .map_err(|err| {
205 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
206 let mut error = AuthRequired::new();
207
208 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
209 error = error.with_description(err.message);
210 }
211
212 anyhow!(error)
213 } else {
214 anyhow!(err)
215 }
216 })?;
217
218 let session_id = response.session_id;
219 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
220 let thread = cx.new(|cx| {
221 AcpThread::new(
222 self.server_name.clone(),
223 self.clone(),
224 project,
225 action_log,
226 session_id.clone(),
227 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
228 watch::Receiver::constant(self.prompt_capabilities),
229 cx,
230 )
231 })?;
232
233 let session = AcpSession {
234 thread: thread.downgrade(),
235 suppress_abort_err: false,
236 };
237 sessions.borrow_mut().insert(session_id, session);
238
239 Ok(thread)
240 })
241 }
242
243 fn auth_methods(&self) -> &[acp::AuthMethod] {
244 &self.auth_methods
245 }
246
247 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
248 let conn = self.connection.clone();
249 cx.foreground_executor().spawn(async move {
250 let result = conn
251 .authenticate(acp::AuthenticateRequest {
252 method_id: method_id.clone(),
253 })
254 .await?;
255
256 Ok(result)
257 })
258 }
259
260 fn prompt(
261 &self,
262 _id: Option<acp_thread::UserMessageId>,
263 params: acp::PromptRequest,
264 cx: &mut App,
265 ) -> Task<Result<acp::PromptResponse>> {
266 let conn = self.connection.clone();
267 let sessions = self.sessions.clone();
268 let session_id = params.session_id.clone();
269 cx.foreground_executor().spawn(async move {
270 let result = conn.prompt(params).await;
271
272 let mut suppress_abort_err = false;
273
274 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
275 suppress_abort_err = session.suppress_abort_err;
276 session.suppress_abort_err = false;
277 }
278
279 match result {
280 Ok(response) => Ok(response),
281 Err(err) => {
282 if err.code != ErrorCode::INTERNAL_ERROR.code {
283 anyhow::bail!(err)
284 }
285
286 let Some(data) = &err.data else {
287 anyhow::bail!(err)
288 };
289
290 // Temporary workaround until the following PR is generally available:
291 // https://github.com/google-gemini/gemini-cli/pull/6656
292
293 #[derive(Deserialize)]
294 #[serde(deny_unknown_fields)]
295 struct ErrorDetails {
296 details: Box<str>,
297 }
298
299 match serde_json::from_value(data.clone()) {
300 Ok(ErrorDetails { details }) => {
301 if suppress_abort_err
302 && (details.contains("This operation was aborted")
303 || details.contains("The user aborted a request"))
304 {
305 Ok(acp::PromptResponse {
306 stop_reason: acp::StopReason::Cancelled,
307 })
308 } else {
309 Err(anyhow!(details))
310 }
311 }
312 Err(_) => Err(anyhow!(err)),
313 }
314 }
315 }
316 })
317 }
318
319 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
320 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
321 session.suppress_abort_err = true;
322 }
323 let conn = self.connection.clone();
324 let params = acp::CancelNotification {
325 session_id: session_id.clone(),
326 };
327 cx.foreground_executor()
328 .spawn(async move { conn.cancel(params).await })
329 .detach();
330 }
331
332 fn list_commands(&self, session_id: &acp::SessionId, cx: &mut App) -> Task<Result<acp::ListCommandsResponse>> {
333 let conn = self.connection.clone();
334 let session_id = session_id.clone();
335 cx.foreground_executor().spawn(async move {
336 conn.list_commands(acp::ListCommandsRequest { session_id }).await
337 .map_err(Into::into)
338 })
339 }
340
341 fn run_command(&self, request: acp::RunCommandRequest, cx: &mut App) -> Task<Result<()>> {
342 let conn = self.connection.clone();
343 cx.foreground_executor().spawn(async move {
344 conn.run_command(request).await
345 .map_err(Into::into)
346 })
347 }
348
349 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
350 self
351 }
352}
353
/// Handler for requests and notifications the agent sends to the client; it is
/// passed to `acp::ClientSideConnection::new` when the connection is created.
struct ClientDelegate {
    /// Session map shared with the owning `AcpConnection`; used to find the
    /// thread a request refers to.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App context handle, cloned for each incoming request.
    cx: AsyncApp,
}
358
359impl acp::Client for ClientDelegate {
360 async fn request_permission(
361 &self,
362 arguments: acp::RequestPermissionRequest,
363 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
364 let cx = &mut self.cx.clone();
365
366 // If always_allow_tool_actions is enabled, then auto-choose the first "Allow" button
367 if AgentSettings::try_read_global(cx, |settings| settings.always_allow_tool_actions)
368 .unwrap_or(false)
369 {
370 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
371 // some tools would (incorrectly) continue to auto-accept.
372 if let Some(allow_once_option) = arguments.options.iter().find_map(|option| {
373 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
374 Some(option.id.clone())
375 } else {
376 None
377 }
378 }) {
379 return Ok(acp::RequestPermissionResponse {
380 outcome: acp::RequestPermissionOutcome::Selected {
381 option_id: allow_once_option,
382 },
383 });
384 }
385 }
386
387 let rx = self
388 .sessions
389 .borrow()
390 .get(&arguments.session_id)
391 .context("Failed to get session")?
392 .thread
393 .update(cx, |thread, cx| {
394 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
395 })?;
396
397 let result = rx?.await;
398
399 let outcome = match result {
400 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
401 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
402 };
403
404 Ok(acp::RequestPermissionResponse { outcome })
405 }
406
407 async fn write_text_file(
408 &self,
409 arguments: acp::WriteTextFileRequest,
410 ) -> Result<(), acp::Error> {
411 let cx = &mut self.cx.clone();
412 let task = self
413 .sessions
414 .borrow()
415 .get(&arguments.session_id)
416 .context("Failed to get session")?
417 .thread
418 .update(cx, |thread, cx| {
419 thread.write_text_file(arguments.path, arguments.content, cx)
420 })?;
421
422 task.await?;
423
424 Ok(())
425 }
426
427 async fn read_text_file(
428 &self,
429 arguments: acp::ReadTextFileRequest,
430 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
431 let cx = &mut self.cx.clone();
432 let task = self
433 .sessions
434 .borrow()
435 .get(&arguments.session_id)
436 .context("Failed to get session")?
437 .thread
438 .update(cx, |thread, cx| {
439 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
440 })?;
441
442 let content = task.await?;
443
444 Ok(acp::ReadTextFileResponse { content })
445 }
446
447 async fn session_notification(
448 &self,
449 notification: acp::SessionNotification,
450 ) -> Result<(), acp::Error> {
451 let cx = &mut self.cx.clone();
452 let sessions = self.sessions.borrow();
453 let session = sessions
454 .get(¬ification.session_id)
455 .context("Failed to get session")?;
456
457 session.thread.update(cx, |thread, cx| {
458 thread.handle_session_update(notification.update, cx)
459 })??;
460
461 Ok(())
462 }
463}