1use acp_thread::AgentConnection;
2use acp_tools::AcpConnectionRegistry;
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
5use anyhow::anyhow;
6use collections::HashMap;
7use futures::AsyncBufReadExt as _;
8use futures::io::BufReader;
9use project::Project;
10use project::agent_server_store::AgentServerCommand;
11use serde::Deserialize;
12
13use std::path::PathBuf;
14use std::{any::Any, cell::RefCell};
15use std::{path::Path, rc::Rc};
16use thiserror::Error;
17
18use anyhow::{Context as _, Result};
19use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
20
21use acp_thread::{AcpThread, AuthRequired, LoadError};
22
23#[derive(Debug, Error)]
24#[error("Unsupported version")]
25pub struct UnsupportedVersion;
26
/// An [`AgentConnection`] to an external agent process that speaks the Agent
/// Client Protocol over the child process's stdin/stdout.
pub struct AcpConnection {
    /// Name the agent server was registered under; also handed to every
    /// [`AcpThread`] created through this connection.
    server_name: SharedString,
    /// Client side of the ACP connection, used to send requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Live sessions keyed by the agent-assigned session ID. The same map is
    /// shared with the [`ClientDelegate`] (which routes agent requests to the
    /// right thread) and with the process-exit watcher.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods the agent advertised in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities the agent advertised in its `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    /// Directory this connection was created for; used as the child's working
    /// directory when the project is local.
    root_dir: PathBuf,
    /// Drives the ACP message I/O loop. Held only to keep the task alive.
    /// NOTE(review): field order determines drop order of these tasks.
    _io_task: Task<Result<()>>,
    /// Waits for the child to exit and reports the exit to every session's thread.
    _wait_task: Task<Result<()>>,
    /// Forwards each line the child writes to stderr to the log.
    _stderr_task: Task<Result<()>>,
}
38
/// Per-session state tracked by an [`AcpConnection`].
pub struct AcpSession {
    /// Weak handle to the [`AcpThread`] backing this session.
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel`. The next completed `prompt` consumes it: if that prompt
    /// failed with an abort-style error, it is reported as
    /// `StopReason::Cancelled` instead of as an error.
    suppress_abort_err: bool,
}
43
44pub async fn connect(
45 server_name: SharedString,
46 command: AgentServerCommand,
47 root_dir: &Path,
48 is_remote: bool,
49 cx: &mut AsyncApp,
50) -> Result<Rc<dyn AgentConnection>> {
51 let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, is_remote, cx).await?;
52 Ok(Rc::new(conn) as _)
53}
54
/// Oldest ACP protocol version accepted from an agent's `initialize` response;
/// agents reporting an older version are rejected with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
56
57impl AcpConnection {
58 pub async fn stdio(
59 server_name: SharedString,
60 command: AgentServerCommand,
61 root_dir: &Path,
62 is_remote: bool,
63 cx: &mut AsyncApp,
64 ) -> Result<Self> {
65 let mut child = util::command::new_smol_command(command.path);
66 child
67 .args(command.args.iter().map(|arg| arg.as_str()))
68 .envs(command.env.iter().flatten())
69 .stdin(std::process::Stdio::piped())
70 .stdout(std::process::Stdio::piped())
71 .stderr(std::process::Stdio::piped())
72 .kill_on_drop(true);
73 if !is_remote {
74 child.current_dir(root_dir);
75 }
76 let mut child = child.spawn()?;
77
78 let stdout = child.stdout.take().context("Failed to take stdout")?;
79 let stdin = child.stdin.take().context("Failed to take stdin")?;
80 let stderr = child.stderr.take().context("Failed to take stderr")?;
81 log::trace!("Spawned (pid: {})", child.id());
82
83 let sessions = Rc::new(RefCell::new(HashMap::default()));
84
85 let client = ClientDelegate {
86 sessions: sessions.clone(),
87 cx: cx.clone(),
88 };
89 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
90 let foreground_executor = cx.foreground_executor().clone();
91 move |fut| {
92 foreground_executor.spawn(fut).detach();
93 }
94 });
95
96 let io_task = cx.background_spawn(io_task);
97
98 let stderr_task = cx.background_spawn(async move {
99 let mut stderr = BufReader::new(stderr);
100 let mut line = String::new();
101 while let Ok(n) = stderr.read_line(&mut line).await
102 && n > 0
103 {
104 log::warn!("agent stderr: {}", &line);
105 line.clear();
106 }
107 Ok(())
108 });
109
110 let wait_task = cx.spawn({
111 let sessions = sessions.clone();
112 async move |cx| {
113 let status = child.status().await?;
114
115 for session in sessions.borrow().values() {
116 session
117 .thread
118 .update(cx, |thread, cx| {
119 thread.emit_load_error(LoadError::Exited { status }, cx)
120 })
121 .ok();
122 }
123
124 anyhow::Ok(())
125 }
126 });
127
128 let connection = Rc::new(connection);
129
130 cx.update(|cx| {
131 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
132 registry.set_active_connection(server_name.clone(), &connection, cx)
133 });
134 })?;
135
136 let response = connection
137 .initialize(acp::InitializeRequest {
138 protocol_version: acp::VERSION,
139 client_capabilities: acp::ClientCapabilities {
140 fs: acp::FileSystemCapability {
141 read_text_file: true,
142 write_text_file: true,
143 },
144 terminal: true,
145 },
146 })
147 .await?;
148
149 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
150 return Err(UnsupportedVersion.into());
151 }
152
153 Ok(Self {
154 auth_methods: response.auth_methods,
155 root_dir: root_dir.to_owned(),
156 connection,
157 server_name,
158 sessions,
159 agent_capabilities: response.agent_capabilities,
160 _io_task: io_task,
161 _wait_task: wait_task,
162 _stderr_task: stderr_task,
163 })
164 }
165
166 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
167 &self.agent_capabilities.prompt_capabilities
168 }
169
170 pub fn root_dir(&self) -> &Path {
171 &self.root_dir
172 }
173}
174
175impl AgentConnection for AcpConnection {
176 fn new_thread(
177 self: Rc<Self>,
178 project: Entity<Project>,
179 cwd: &Path,
180 cx: &mut App,
181 ) -> Task<Result<Entity<AcpThread>>> {
182 let conn = self.connection.clone();
183 let sessions = self.sessions.clone();
184 let cwd = cwd.to_path_buf();
185 let context_server_store = project.read(cx).context_server_store().read(cx);
186 let mcp_servers = if project.read(cx).is_local() {
187 context_server_store
188 .configured_server_ids()
189 .iter()
190 .filter_map(|id| {
191 let configuration = context_server_store.configuration_for_server(id)?;
192 let command = configuration.command();
193 Some(acp::McpServer {
194 name: id.0.to_string(),
195 command: command.path.clone(),
196 args: command.args.clone(),
197 env: if let Some(env) = command.env.as_ref() {
198 env.iter()
199 .map(|(name, value)| acp::EnvVariable {
200 name: name.clone(),
201 value: value.clone(),
202 })
203 .collect()
204 } else {
205 vec![]
206 },
207 })
208 })
209 .collect()
210 } else {
211 // In SSH projects, the external agent is running on the remote
212 // machine, and currently we only run MCP servers on the local
213 // machine. So don't pass any MCP servers to the agent in that case.
214 Vec::new()
215 };
216
217 cx.spawn(async move |cx| {
218 let response = conn
219 .new_session(acp::NewSessionRequest { mcp_servers, cwd })
220 .await
221 .map_err(|err| {
222 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
223 let mut error = AuthRequired::new();
224
225 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
226 error = error.with_description(err.message);
227 }
228
229 anyhow!(error)
230 } else {
231 anyhow!(err)
232 }
233 })?;
234
235 let session_id = response.session_id;
236 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
237 let thread = cx.new(|cx| {
238 AcpThread::new(
239 self.server_name.clone(),
240 self.clone(),
241 project,
242 action_log,
243 session_id.clone(),
244 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
245 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
246 cx,
247 )
248 })?;
249
250 let session = AcpSession {
251 thread: thread.downgrade(),
252 suppress_abort_err: false,
253 };
254 sessions.borrow_mut().insert(session_id, session);
255
256 Ok(thread)
257 })
258 }
259
260 fn auth_methods(&self) -> &[acp::AuthMethod] {
261 &self.auth_methods
262 }
263
264 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
265 let conn = self.connection.clone();
266 cx.foreground_executor().spawn(async move {
267 let result = conn
268 .authenticate(acp::AuthenticateRequest {
269 method_id: method_id.clone(),
270 })
271 .await?;
272
273 Ok(result)
274 })
275 }
276
277 fn prompt(
278 &self,
279 _id: Option<acp_thread::UserMessageId>,
280 params: acp::PromptRequest,
281 cx: &mut App,
282 ) -> Task<Result<acp::PromptResponse>> {
283 let conn = self.connection.clone();
284 let sessions = self.sessions.clone();
285 let session_id = params.session_id.clone();
286 cx.foreground_executor().spawn(async move {
287 let result = conn.prompt(params).await;
288
289 let mut suppress_abort_err = false;
290
291 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
292 suppress_abort_err = session.suppress_abort_err;
293 session.suppress_abort_err = false;
294 }
295
296 match result {
297 Ok(response) => Ok(response),
298 Err(err) => {
299 if err.code != ErrorCode::INTERNAL_ERROR.code {
300 anyhow::bail!(err)
301 }
302
303 let Some(data) = &err.data else {
304 anyhow::bail!(err)
305 };
306
307 // Temporary workaround until the following PR is generally available:
308 // https://github.com/google-gemini/gemini-cli/pull/6656
309
310 #[derive(Deserialize)]
311 #[serde(deny_unknown_fields)]
312 struct ErrorDetails {
313 details: Box<str>,
314 }
315
316 match serde_json::from_value(data.clone()) {
317 Ok(ErrorDetails { details }) => {
318 if suppress_abort_err
319 && (details.contains("This operation was aborted")
320 || details.contains("The user aborted a request"))
321 {
322 Ok(acp::PromptResponse {
323 stop_reason: acp::StopReason::Cancelled,
324 })
325 } else {
326 Err(anyhow!(details))
327 }
328 }
329 Err(_) => Err(anyhow!(err)),
330 }
331 }
332 }
333 })
334 }
335
336 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
337 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
338 session.suppress_abort_err = true;
339 }
340 let conn = self.connection.clone();
341 let params = acp::CancelNotification {
342 session_id: session_id.clone(),
343 };
344 cx.foreground_executor()
345 .spawn(async move { conn.cancel(params).await })
346 .detach();
347 }
348
349 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
350 self
351 }
352}
353
/// Handles requests and notifications sent by the agent to the client, routing
/// each one to the [`AcpThread`] of the session it names.
struct ClientDelegate {
    /// The same session map held by the owning [`AcpConnection`].
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app context used to read and update threads from the handlers.
    cx: AsyncApp,
}
358
359impl acp::Client for ClientDelegate {
360 async fn request_permission(
361 &self,
362 arguments: acp::RequestPermissionRequest,
363 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
364 let cx = &mut self.cx.clone();
365
366 let task = self
367 .session_thread(&arguments.session_id)?
368 .update(cx, |thread, cx| {
369 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
370 })??;
371
372 let outcome = task.await;
373
374 Ok(acp::RequestPermissionResponse { outcome })
375 }
376
377 async fn write_text_file(
378 &self,
379 arguments: acp::WriteTextFileRequest,
380 ) -> Result<(), acp::Error> {
381 let cx = &mut self.cx.clone();
382 let task = self
383 .session_thread(&arguments.session_id)?
384 .update(cx, |thread, cx| {
385 thread.write_text_file(arguments.path, arguments.content, cx)
386 })?;
387
388 task.await?;
389
390 Ok(())
391 }
392
393 async fn read_text_file(
394 &self,
395 arguments: acp::ReadTextFileRequest,
396 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
397 let task = self.session_thread(&arguments.session_id)?.update(
398 &mut self.cx.clone(),
399 |thread, cx| {
400 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
401 },
402 )?;
403
404 let content = task.await?;
405
406 Ok(acp::ReadTextFileResponse { content })
407 }
408
409 async fn session_notification(
410 &self,
411 notification: acp::SessionNotification,
412 ) -> Result<(), acp::Error> {
413 self.session_thread(¬ification.session_id)?
414 .update(&mut self.cx.clone(), |thread, cx| {
415 thread.handle_session_update(notification.update, cx)
416 })??;
417
418 Ok(())
419 }
420
421 async fn create_terminal(
422 &self,
423 args: acp::CreateTerminalRequest,
424 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
425 let terminal = self
426 .session_thread(&args.session_id)?
427 .update(&mut self.cx.clone(), |thread, cx| {
428 thread.create_terminal(
429 args.command,
430 args.args,
431 args.env,
432 args.cwd,
433 args.output_byte_limit,
434 cx,
435 )
436 })?
437 .await?;
438 Ok(
439 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
440 terminal_id: terminal.id().clone(),
441 })?,
442 )
443 }
444
445 async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
446 self.session_thread(&args.session_id)?
447 .update(&mut self.cx.clone(), |thread, cx| {
448 thread.kill_terminal(args.terminal_id, cx)
449 })??;
450
451 Ok(())
452 }
453
454 async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
455 self.session_thread(&args.session_id)?
456 .update(&mut self.cx.clone(), |thread, cx| {
457 thread.release_terminal(args.terminal_id, cx)
458 })??;
459
460 Ok(())
461 }
462
463 async fn terminal_output(
464 &self,
465 args: acp::TerminalOutputRequest,
466 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
467 self.session_thread(&args.session_id)?
468 .read_with(&mut self.cx.clone(), |thread, cx| {
469 let out = thread
470 .terminal(args.terminal_id)?
471 .read(cx)
472 .current_output(cx);
473
474 Ok(out)
475 })?
476 }
477
478 async fn wait_for_terminal_exit(
479 &self,
480 args: acp::WaitForTerminalExitRequest,
481 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
482 let exit_status = self
483 .session_thread(&args.session_id)?
484 .update(&mut self.cx.clone(), |thread, cx| {
485 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
486 })??
487 .await;
488
489 Ok(acp::WaitForTerminalExitResponse { exit_status })
490 }
491}
492
493impl ClientDelegate {
494 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
495 let sessions = self.sessions.borrow();
496 sessions
497 .get(session_id)
498 .context("Failed to get session")
499 .map(|session| session.thread.clone())
500 }
501}