1use crate::AgentServerCommand;
2use acp_thread::AgentConnection;
3use acp_tools::AcpConnectionRegistry;
4use action_log::ActionLog;
5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
6use anyhow::anyhow;
7use collections::HashMap;
8use futures::AsyncBufReadExt as _;
9use futures::io::BufReader;
10use project::Project;
11use serde::Deserialize;
12use util::ResultExt as _;
13
14use std::{any::Any, cell::RefCell};
15use std::{path::Path, rc::Rc};
16use thiserror::Error;
17
18use anyhow::{Context as _, Result};
19use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
20
21use acp_thread::{AcpThread, AuthRequired, LoadError};
22
23#[derive(Debug, Error)]
24#[error("Unsupported version")]
25pub struct UnsupportedVersion;
26
27pub struct AcpConnection {
28 server_name: SharedString,
29 connection: Rc<acp::ClientSideConnection>,
30 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
31 auth_methods: Vec<acp::AuthMethod>,
32 agent_capabilities: acp::AgentCapabilities,
33 default_mode: Option<acp::SessionModeId>,
34 root_dir: PathBuf,
35 // NB: Don't move this into the wait_task, since we need to ensure the process is
36 // killed on drop (setting kill_on_drop on the command seems to not always work).
37 child: smol::process::Child,
38 _io_task: Task<Result<()>>,
39 _wait_task: Task<Result<()>>,
40 _stderr_task: Task<Result<()>>,
41}
42
43pub struct AcpSession {
44 thread: WeakEntity<AcpThread>,
45 suppress_abort_err: bool,
46}
47
48pub async fn connect(
49 server_name: SharedString,
50 command: AgentServerCommand,
51 root_dir: &Path,
52 cx: &mut AsyncApp,
53) -> Result<Rc<dyn AgentConnection>> {
54 let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
55 Ok(Rc::new(conn) as _)
56}
57
58const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
59
60impl AcpConnection {
61 pub async fn stdio(
62 server_name: SharedString,
63 command: AgentServerCommand,
64 root_dir: &Path,
65 cx: &mut AsyncApp,
66 ) -> Result<Self> {
67 let mut child = util::command::new_smol_command(command.path)
68 .args(command.args.iter().map(|arg| arg.as_str()))
69 .envs(command.env.iter().flatten())
70 .current_dir(root_dir)
71 .stdin(std::process::Stdio::piped())
72 .stdout(std::process::Stdio::piped())
73 .stderr(std::process::Stdio::piped());
74 if !is_remote {
75 child.current_dir(root_dir);
76 }
77 let mut child = child.spawn()?;
78
79 let stdout = child.stdout.take().context("Failed to take stdout")?;
80 let stdin = child.stdin.take().context("Failed to take stdin")?;
81 let stderr = child.stderr.take().context("Failed to take stderr")?;
82 log::trace!("Spawned (pid: {})", child.id());
83
84 let sessions = Rc::new(RefCell::new(HashMap::default()));
85
86 let client = ClientDelegate {
87 sessions: sessions.clone(),
88 cx: cx.clone(),
89 };
90 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
91 let foreground_executor = cx.foreground_executor().clone();
92 move |fut| {
93 foreground_executor.spawn(fut).detach();
94 }
95 });
96
97 let io_task = cx.background_spawn(io_task);
98
99 let stderr_task = cx.background_spawn(async move {
100 let mut stderr = BufReader::new(stderr);
101 let mut line = String::new();
102 while let Ok(n) = stderr.read_line(&mut line).await
103 && n > 0
104 {
105 log::warn!("agent stderr: {}", &line);
106 line.clear();
107 }
108 Ok(())
109 });
110
111 let wait_task = cx.spawn({
112 let sessions = sessions.clone();
113 let status_fut = child.status();
114 async move |cx| {
115 let status = status_fut.await?;
116
117 for session in sessions.borrow().values() {
118 session
119 .thread
120 .update(cx, |thread, cx| {
121 thread.emit_load_error(LoadError::Exited { status }, cx)
122 })
123 .ok();
124 }
125
126 anyhow::Ok(())
127 }
128 });
129
130 let connection = Rc::new(connection);
131
132 cx.update(|cx| {
133 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
134 registry.set_active_connection(server_name.clone(), &connection, cx)
135 });
136 })?;
137
138 let response = connection
139 .initialize(acp::InitializeRequest {
140 protocol_version: acp::VERSION,
141 client_capabilities: acp::ClientCapabilities {
142 fs: acp::FileSystemCapability {
143 read_text_file: true,
144 write_text_file: true,
145 },
146 terminal: true,
147 },
148 })
149 .await?;
150
151 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
152 return Err(UnsupportedVersion.into());
153 }
154
155 Ok(Self {
156 auth_methods: response.auth_methods,
157 connection,
158 server_name,
159 sessions,
160 agent_capabilities: response.agent_capabilities,
161 _io_task: io_task,
162 _wait_task: wait_task,
163 _stderr_task: stderr_task,
164 child,
165 })
166 }
167
168 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
169 &self.agent_capabilities.prompt_capabilities
170 }
171}
172
173impl Drop for AcpConnection {
174 fn drop(&mut self) {
175 // See the comment on the child field.
176 self.child.kill().log_err();
177 }
178}
179
180impl AgentConnection for AcpConnection {
181 fn new_thread(
182 self: Rc<Self>,
183 project: Entity<Project>,
184 cwd: &Path,
185 cx: &mut App,
186 ) -> Task<Result<Entity<AcpThread>>> {
187 let conn = self.connection.clone();
188 let sessions = self.sessions.clone();
189 let cwd = cwd.to_path_buf();
190 let context_server_store = project.read(cx).context_server_store().read(cx);
191 let mcp_servers = context_server_store
192 .configured_server_ids()
193 .iter()
194 .filter_map(|id| {
195 let configuration = context_server_store.configuration_for_server(id)?;
196 let command = configuration.command();
197 Some(acp::McpServer {
198 name: id.0.to_string(),
199 command: command.path.clone(),
200 args: command.args.clone(),
201 env: if let Some(env) = command.env.as_ref() {
202 env.iter()
203 .map(|(name, value)| acp::EnvVariable {
204 name: name.clone(),
205 value: value.clone(),
206 })
207 .collect()
208 } else {
209 vec![]
210 },
211 })
212 })
213 .collect();
214
215 cx.spawn(async move |cx| {
216 let response = conn
217 .new_session(acp::NewSessionRequest { mcp_servers, cwd })
218 .await
219 .map_err(|err| {
220 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
221 let mut error = AuthRequired::new();
222
223 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
224 error = error.with_description(err.message);
225 }
226
227 anyhow!(error)
228 } else {
229 anyhow!(err)
230 }
231 })?;
232
233 let session_id = response.session_id;
234 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
235 let thread = cx.new(|cx| {
236 AcpThread::new(
237 self.server_name.clone(),
238 self.clone(),
239 project,
240 action_log,
241 session_id.clone(),
242 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
243 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
244 cx,
245 )
246 })?;
247
248 let session = AcpSession {
249 thread: thread.downgrade(),
250 suppress_abort_err: false,
251 };
252 sessions.borrow_mut().insert(session_id, session);
253
254 Ok(thread)
255 })
256 }
257
258 fn auth_methods(&self) -> &[acp::AuthMethod] {
259 &self.auth_methods
260 }
261
262 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
263 let conn = self.connection.clone();
264 cx.foreground_executor().spawn(async move {
265 let result = conn
266 .authenticate(acp::AuthenticateRequest {
267 method_id: method_id.clone(),
268 })
269 .await?;
270
271 Ok(result)
272 })
273 }
274
275 fn prompt(
276 &self,
277 _id: Option<acp_thread::UserMessageId>,
278 params: acp::PromptRequest,
279 cx: &mut App,
280 ) -> Task<Result<acp::PromptResponse>> {
281 let conn = self.connection.clone();
282 let sessions = self.sessions.clone();
283 let session_id = params.session_id.clone();
284 cx.foreground_executor().spawn(async move {
285 let result = conn.prompt(params).await;
286
287 let mut suppress_abort_err = false;
288
289 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
290 suppress_abort_err = session.suppress_abort_err;
291 session.suppress_abort_err = false;
292 }
293
294 match result {
295 Ok(response) => Ok(response),
296 Err(err) => {
297 if err.code != ErrorCode::INTERNAL_ERROR.code {
298 anyhow::bail!(err)
299 }
300
301 let Some(data) = &err.data else {
302 anyhow::bail!(err)
303 };
304
305 // Temporary workaround until the following PR is generally available:
306 // https://github.com/google-gemini/gemini-cli/pull/6656
307
308 #[derive(Deserialize)]
309 #[serde(deny_unknown_fields)]
310 struct ErrorDetails {
311 details: Box<str>,
312 }
313
314 match serde_json::from_value(data.clone()) {
315 Ok(ErrorDetails { details }) => {
316 if suppress_abort_err
317 && (details.contains("This operation was aborted")
318 || details.contains("The user aborted a request"))
319 {
320 Ok(acp::PromptResponse {
321 stop_reason: acp::StopReason::Cancelled,
322 })
323 } else {
324 Err(anyhow!(details))
325 }
326 }
327 Err(_) => Err(anyhow!(err)),
328 }
329 }
330 }
331 })
332 }
333
334 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
335 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
336 session.suppress_abort_err = true;
337 }
338 let conn = self.connection.clone();
339 let params = acp::CancelNotification {
340 session_id: session_id.clone(),
341 };
342 cx.foreground_executor()
343 .spawn(async move { conn.cancel(params).await })
344 .detach();
345 }
346
347 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
348 self
349 }
350}
351
352struct ClientDelegate {
353 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
354 cx: AsyncApp,
355}
356
357impl acp::Client for ClientDelegate {
358 async fn request_permission(
359 &self,
360 arguments: acp::RequestPermissionRequest,
361 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
362 let cx = &mut self.cx.clone();
363
364 let task = self
365 .session_thread(&arguments.session_id)?
366 .update(cx, |thread, cx| {
367 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
368 })??;
369
370 let outcome = task.await;
371
372 Ok(acp::RequestPermissionResponse { outcome })
373 }
374
375 async fn write_text_file(
376 &self,
377 arguments: acp::WriteTextFileRequest,
378 ) -> Result<(), acp::Error> {
379 let cx = &mut self.cx.clone();
380 let task = self
381 .session_thread(&arguments.session_id)?
382 .update(cx, |thread, cx| {
383 thread.write_text_file(arguments.path, arguments.content, cx)
384 })?;
385
386 task.await?;
387
388 Ok(())
389 }
390
391 async fn read_text_file(
392 &self,
393 arguments: acp::ReadTextFileRequest,
394 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
395 let task = self.session_thread(&arguments.session_id)?.update(
396 &mut self.cx.clone(),
397 |thread, cx| {
398 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
399 },
400 )?;
401
402 let content = task.await?;
403
404 Ok(acp::ReadTextFileResponse { content })
405 }
406
407 async fn session_notification(
408 &self,
409 notification: acp::SessionNotification,
410 ) -> Result<(), acp::Error> {
411 self.session_thread(¬ification.session_id)?
412 .update(&mut self.cx.clone(), |thread, cx| {
413 thread.handle_session_update(notification.update, cx)
414 })??;
415
416 Ok(())
417 }
418
419 async fn create_terminal(
420 &self,
421 args: acp::CreateTerminalRequest,
422 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
423 let terminal = self
424 .session_thread(&args.session_id)?
425 .update(&mut self.cx.clone(), |thread, cx| {
426 thread.create_terminal(
427 args.command,
428 args.args,
429 args.env,
430 args.cwd,
431 args.output_byte_limit,
432 cx,
433 )
434 })?
435 .await?;
436 Ok(
437 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
438 terminal_id: terminal.id().clone(),
439 })?,
440 )
441 }
442
443 async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
444 self.session_thread(&args.session_id)?
445 .update(&mut self.cx.clone(), |thread, cx| {
446 thread.kill_terminal(args.terminal_id, cx)
447 })??;
448
449 Ok(())
450 }
451
452 async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
453 self.session_thread(&args.session_id)?
454 .update(&mut self.cx.clone(), |thread, cx| {
455 thread.release_terminal(args.terminal_id, cx)
456 })??;
457
458 Ok(())
459 }
460
461 async fn terminal_output(
462 &self,
463 args: acp::TerminalOutputRequest,
464 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
465 self.session_thread(&args.session_id)?
466 .read_with(&mut self.cx.clone(), |thread, cx| {
467 let out = thread
468 .terminal(args.terminal_id)?
469 .read(cx)
470 .current_output(cx);
471
472 Ok(out)
473 })?
474 }
475
476 async fn wait_for_terminal_exit(
477 &self,
478 args: acp::WaitForTerminalExitRequest,
479 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
480 let exit_status = self
481 .session_thread(&args.session_id)?
482 .update(&mut self.cx.clone(), |thread, cx| {
483 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
484 })??
485 .await;
486
487 Ok(acp::WaitForTerminalExitResponse { exit_status })
488 }
489}
490
491impl ClientDelegate {
492 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
493 let sessions = self.sessions.borrow();
494 sessions
495 .get(session_id)
496 .context("Failed to get session")
497 .map(|session| session.thread.clone())
498 }
499}