1use crate::AgentServerCommand;
2use acp_thread::AgentConnection;
3use acp_tools::AcpConnectionRegistry;
4use action_log::ActionLog;
5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
6use anyhow::anyhow;
7use collections::HashMap;
8use futures::AsyncBufReadExt as _;
9use futures::io::BufReader;
10use project::Project;
11use serde::Deserialize;
12
13use std::{any::Any, cell::RefCell};
14use std::{path::Path, rc::Rc};
15use thiserror::Error;
16
17use anyhow::{Context as _, Result};
18use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
19
20use acp_thread::{AcpThread, AuthRequired, LoadError};
21
22#[derive(Debug, Error)]
23#[error("Unsupported version")]
24pub struct UnsupportedVersion;
25
/// An [`AgentConnection`] to an external agent process that speaks the Agent
/// Client Protocol over the process's stdin/stdout.
pub struct AcpConnection {
    /// Name of the agent server; used when registering the connection and when
    /// creating threads.
    server_name: SharedString,
    /// The underlying ACP client-side connection.
    connection: Rc<acp::ClientSideConnection>,
    /// Live sessions keyed by session id. Shared with the `ClientDelegate` so
    /// agent-initiated requests can be routed to the right thread.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities advertised by the agent in its `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    // The tasks below are held only so their work stays alive as long as the
    // connection does. `_wait_task` owns the child process, which is spawned
    // with `kill_on_drop(true)`. Field order determines drop order, so keep it.
    _io_task: Task<Result<()>>,
    _wait_task: Task<Result<()>>,
    _stderr_task: Task<Result<()>>,
}
36
/// Per-session state tracked by an [`AcpConnection`].
pub struct AcpSession {
    /// The thread backing this session, held weakly.
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel` so the in-flight `prompt` can report an "aborted" internal
    /// error as a cancellation. Reset when that prompt's result is processed.
    suppress_abort_err: bool,
}
41
42pub async fn connect(
43 server_name: SharedString,
44 command: AgentServerCommand,
45 root_dir: &Path,
46 cx: &mut AsyncApp,
47) -> Result<Rc<dyn AgentConnection>> {
48 let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
49 Ok(Rc::new(conn) as _)
50}
51
/// Oldest ACP protocol version accepted from an agent. `AcpConnection::stdio`
/// rejects agents that report an older version with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
53
54impl AcpConnection {
55 pub async fn stdio(
56 server_name: SharedString,
57 command: AgentServerCommand,
58 root_dir: &Path,
59 cx: &mut AsyncApp,
60 ) -> Result<Self> {
61 let mut child = util::command::new_smol_command(command.path)
62 .args(command.args.iter().map(|arg| arg.as_str()))
63 .envs(command.env.iter().flatten())
64 .current_dir(root_dir)
65 .stdin(std::process::Stdio::piped())
66 .stdout(std::process::Stdio::piped())
67 .stderr(std::process::Stdio::piped())
68 .kill_on_drop(true)
69 .spawn()?;
70
71 let stdout = child.stdout.take().context("Failed to take stdout")?;
72 let stdin = child.stdin.take().context("Failed to take stdin")?;
73 let stderr = child.stderr.take().context("Failed to take stderr")?;
74 log::trace!("Spawned (pid: {})", child.id());
75
76 let sessions = Rc::new(RefCell::new(HashMap::default()));
77
78 let client = ClientDelegate {
79 sessions: sessions.clone(),
80 cx: cx.clone(),
81 };
82 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
83 let foreground_executor = cx.foreground_executor().clone();
84 move |fut| {
85 foreground_executor.spawn(fut).detach();
86 }
87 });
88
89 let io_task = cx.background_spawn(io_task);
90
91 let stderr_task = cx.background_spawn(async move {
92 let mut stderr = BufReader::new(stderr);
93 let mut line = String::new();
94 while let Ok(n) = stderr.read_line(&mut line).await
95 && n > 0
96 {
97 log::warn!("agent stderr: {}", &line);
98 line.clear();
99 }
100 Ok(())
101 });
102
103 let wait_task = cx.spawn({
104 let sessions = sessions.clone();
105 async move |cx| {
106 let status = child.status().await?;
107
108 for session in sessions.borrow().values() {
109 session
110 .thread
111 .update(cx, |thread, cx| {
112 thread.emit_load_error(LoadError::Exited { status }, cx)
113 })
114 .ok();
115 }
116
117 anyhow::Ok(())
118 }
119 });
120
121 let connection = Rc::new(connection);
122
123 cx.update(|cx| {
124 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
125 registry.set_active_connection(server_name.clone(), &connection, cx)
126 });
127 })?;
128
129 let response = connection
130 .initialize(acp::InitializeRequest {
131 protocol_version: acp::VERSION,
132 client_capabilities: acp::ClientCapabilities {
133 fs: acp::FileSystemCapability {
134 read_text_file: true,
135 write_text_file: true,
136 },
137 terminal: true,
138 },
139 })
140 .await?;
141
142 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
143 return Err(UnsupportedVersion.into());
144 }
145
146 Ok(Self {
147 auth_methods: response.auth_methods,
148 connection,
149 server_name,
150 sessions,
151 agent_capabilities: response.agent_capabilities,
152 _io_task: io_task,
153 _wait_task: wait_task,
154 _stderr_task: stderr_task,
155 })
156 }
157
158 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
159 &self.agent_capabilities.prompt_capabilities
160 }
161}
162
163impl AgentConnection for AcpConnection {
164 fn new_thread(
165 self: Rc<Self>,
166 project: Entity<Project>,
167 cwd: &Path,
168 cx: &mut App,
169 ) -> Task<Result<Entity<AcpThread>>> {
170 let conn = self.connection.clone();
171 let sessions = self.sessions.clone();
172 let cwd = cwd.to_path_buf();
173 let context_server_store = project.read(cx).context_server_store().read(cx);
174 let mcp_servers = context_server_store
175 .configured_server_ids()
176 .iter()
177 .filter_map(|id| {
178 let configuration = context_server_store.configuration_for_server(id)?;
179 let command = configuration.command();
180 Some(acp::McpServer {
181 name: id.0.to_string(),
182 command: command.path.clone(),
183 args: command.args.clone(),
184 env: if let Some(env) = command.env.as_ref() {
185 env.iter()
186 .map(|(name, value)| acp::EnvVariable {
187 name: name.clone(),
188 value: value.clone(),
189 })
190 .collect()
191 } else {
192 vec![]
193 },
194 })
195 })
196 .collect();
197
198 cx.spawn(async move |cx| {
199 let response = conn
200 .new_session(acp::NewSessionRequest { mcp_servers, cwd })
201 .await
202 .map_err(|err| {
203 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
204 let mut error = AuthRequired::new();
205
206 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
207 error = error.with_description(err.message);
208 }
209
210 anyhow!(error)
211 } else {
212 anyhow!(err)
213 }
214 })?;
215
216 let session_id = response.session_id;
217 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
218 let thread = cx.new(|cx| {
219 AcpThread::new(
220 self.server_name.clone(),
221 self.clone(),
222 project,
223 action_log,
224 session_id.clone(),
225 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
226 watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
227 cx,
228 )
229 })?;
230
231 let session = AcpSession {
232 thread: thread.downgrade(),
233 suppress_abort_err: false,
234 };
235 sessions.borrow_mut().insert(session_id, session);
236
237 Ok(thread)
238 })
239 }
240
241 fn auth_methods(&self) -> &[acp::AuthMethod] {
242 &self.auth_methods
243 }
244
245 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
246 let conn = self.connection.clone();
247 cx.foreground_executor().spawn(async move {
248 let result = conn
249 .authenticate(acp::AuthenticateRequest {
250 method_id: method_id.clone(),
251 })
252 .await?;
253
254 Ok(result)
255 })
256 }
257
258 fn prompt(
259 &self,
260 _id: Option<acp_thread::UserMessageId>,
261 params: acp::PromptRequest,
262 cx: &mut App,
263 ) -> Task<Result<acp::PromptResponse>> {
264 let conn = self.connection.clone();
265 let sessions = self.sessions.clone();
266 let session_id = params.session_id.clone();
267 cx.foreground_executor().spawn(async move {
268 let result = conn.prompt(params).await;
269
270 let mut suppress_abort_err = false;
271
272 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
273 suppress_abort_err = session.suppress_abort_err;
274 session.suppress_abort_err = false;
275 }
276
277 match result {
278 Ok(response) => Ok(response),
279 Err(err) => {
280 if err.code != ErrorCode::INTERNAL_ERROR.code {
281 anyhow::bail!(err)
282 }
283
284 let Some(data) = &err.data else {
285 anyhow::bail!(err)
286 };
287
288 // Temporary workaround until the following PR is generally available:
289 // https://github.com/google-gemini/gemini-cli/pull/6656
290
291 #[derive(Deserialize)]
292 #[serde(deny_unknown_fields)]
293 struct ErrorDetails {
294 details: Box<str>,
295 }
296
297 match serde_json::from_value(data.clone()) {
298 Ok(ErrorDetails { details }) => {
299 if suppress_abort_err
300 && (details.contains("This operation was aborted")
301 || details.contains("The user aborted a request"))
302 {
303 Ok(acp::PromptResponse {
304 stop_reason: acp::StopReason::Cancelled,
305 })
306 } else {
307 Err(anyhow!(details))
308 }
309 }
310 Err(_) => Err(anyhow!(err)),
311 }
312 }
313 }
314 })
315 }
316
317 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
318 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
319 session.suppress_abort_err = true;
320 }
321 let conn = self.connection.clone();
322 let params = acp::CancelNotification {
323 session_id: session_id.clone(),
324 };
325 cx.foreground_executor()
326 .spawn(async move { conn.cancel(params).await })
327 .detach();
328 }
329
330 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
331 self
332 }
333}
334
/// Handles agent-initiated requests and notifications for an [`AcpConnection`],
/// forwarding each one to the [`AcpThread`] registered for its session id.
struct ClientDelegate {
    /// Sessions shared with the owning [`AcpConnection`].
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app context used to update and read thread entities.
    cx: AsyncApp,
}
339
340impl acp::Client for ClientDelegate {
341 async fn request_permission(
342 &self,
343 arguments: acp::RequestPermissionRequest,
344 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
345 let cx = &mut self.cx.clone();
346
347 let task = self
348 .session_thread(&arguments.session_id)?
349 .update(cx, |thread, cx| {
350 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
351 })??;
352
353 let outcome = task.await;
354
355 Ok(acp::RequestPermissionResponse { outcome })
356 }
357
358 async fn write_text_file(
359 &self,
360 arguments: acp::WriteTextFileRequest,
361 ) -> Result<(), acp::Error> {
362 let cx = &mut self.cx.clone();
363 let task = self
364 .session_thread(&arguments.session_id)?
365 .update(cx, |thread, cx| {
366 thread.write_text_file(arguments.path, arguments.content, cx)
367 })?;
368
369 task.await?;
370
371 Ok(())
372 }
373
374 async fn read_text_file(
375 &self,
376 arguments: acp::ReadTextFileRequest,
377 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
378 let task = self.session_thread(&arguments.session_id)?.update(
379 &mut self.cx.clone(),
380 |thread, cx| {
381 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
382 },
383 )?;
384
385 let content = task.await?;
386
387 Ok(acp::ReadTextFileResponse { content })
388 }
389
390 async fn session_notification(
391 &self,
392 notification: acp::SessionNotification,
393 ) -> Result<(), acp::Error> {
394 self.session_thread(¬ification.session_id)?
395 .update(&mut self.cx.clone(), |thread, cx| {
396 thread.handle_session_update(notification.update, cx)
397 })??;
398
399 Ok(())
400 }
401
402 async fn create_terminal(
403 &self,
404 args: acp::CreateTerminalRequest,
405 ) -> Result<acp::CreateTerminalResponse, acp::Error> {
406 let terminal = self
407 .session_thread(&args.session_id)?
408 .update(&mut self.cx.clone(), |thread, cx| {
409 thread.create_terminal(
410 args.command,
411 args.args,
412 args.env,
413 args.cwd,
414 args.output_byte_limit,
415 cx,
416 )
417 })?
418 .await?;
419 Ok(
420 terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
421 terminal_id: terminal.id().clone(),
422 })?,
423 )
424 }
425
426 async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
427 self.session_thread(&args.session_id)?
428 .update(&mut self.cx.clone(), |thread, cx| {
429 thread.kill_terminal(args.terminal_id, cx)
430 })??;
431
432 Ok(())
433 }
434
435 async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
436 self.session_thread(&args.session_id)?
437 .update(&mut self.cx.clone(), |thread, cx| {
438 thread.release_terminal(args.terminal_id, cx)
439 })??;
440
441 Ok(())
442 }
443
444 async fn terminal_output(
445 &self,
446 args: acp::TerminalOutputRequest,
447 ) -> Result<acp::TerminalOutputResponse, acp::Error> {
448 self.session_thread(&args.session_id)?
449 .read_with(&mut self.cx.clone(), |thread, cx| {
450 let out = thread
451 .terminal(args.terminal_id)?
452 .read(cx)
453 .current_output(cx);
454
455 Ok(out)
456 })?
457 }
458
459 async fn wait_for_terminal_exit(
460 &self,
461 args: acp::WaitForTerminalExitRequest,
462 ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
463 let exit_status = self
464 .session_thread(&args.session_id)?
465 .update(&mut self.cx.clone(), |thread, cx| {
466 anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
467 })??
468 .await;
469
470 Ok(acp::WaitForTerminalExitResponse { exit_status })
471 }
472}
473
474impl ClientDelegate {
475 fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
476 let sessions = self.sessions.borrow();
477 sessions
478 .get(session_id)
479 .context("Failed to get session")
480 .map(|session| session.thread.clone())
481 }
482}