1use crate::AgentServerCommand;
2use acp_thread::AgentConnection;
3use acp_tools::AcpConnectionRegistry;
4use action_log::ActionLog;
5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
6use agent_settings::AgentSettings;
7use anyhow::anyhow;
8use collections::HashMap;
9use futures::AsyncBufReadExt as _;
10use futures::channel::oneshot;
11use futures::io::BufReader;
12use project::Project;
13use serde::Deserialize;
14use settings::Settings as _;
15use std::{any::Any, cell::RefCell};
16use std::{path::Path, rc::Rc};
17use thiserror::Error;
18
19use anyhow::{Context as _, Result};
20use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
21
22use acp_thread::{AcpThread, AuthRequired, LoadError};
23
24#[derive(Debug, Error)]
25#[error("Unsupported version")]
26pub struct UnsupportedVersion;
27
/// An [`AgentConnection`] backed by an agent subprocess that speaks the Agent Client
/// Protocol over its stdin/stdout.
pub struct AcpConnection {
    /// Name of the agent server; passed to every [`AcpThread`] created by this connection
    /// and used when registering the connection with the `AcpConnectionRegistry`.
    server_name: SharedString,
    /// ACP connection used to send requests and notifications to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Open sessions keyed by ACP session id. Shared with the `ClientDelegate` that
    /// handles agent-initiated calls and with the process-exit watcher task.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised in the agent's `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Prompt capabilities advertised in the agent's `initialize` response.
    prompt_capabilities: acp::PromptCapabilities,
    // Background tasks stored here so they are tied to the connection's lifetime:
    // pumping ACP I/O, waiting for the child process to exit, and forwarding the
    // child's stderr to the log.
    _io_task: Task<Result<()>>,
    _wait_task: Task<Result<()>>,
    _stderr_task: Task<Result<()>>,
}
38
/// Per-session state tracked by an [`AcpConnection`].
pub struct AcpSession {
    /// Weak handle to the thread created for this session.
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel`. When the next `prompt` for this session fails with an
    /// "aborted" internal error, the failure is reported as `StopReason::Cancelled`
    /// instead. Cleared after each prompt completes.
    suppress_abort_err: bool,
}
43
44pub async fn connect(
45 server_name: SharedString,
46 command: AgentServerCommand,
47 root_dir: &Path,
48 cx: &mut AsyncApp,
49) -> Result<Rc<dyn AgentConnection>> {
50 let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
51 Ok(Rc::new(conn) as _)
52}
53
/// Oldest ACP protocol version accepted from an agent's `initialize` response;
/// anything older is rejected with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
55
56impl AcpConnection {
57 pub async fn stdio(
58 server_name: SharedString,
59 command: AgentServerCommand,
60 root_dir: &Path,
61 cx: &mut AsyncApp,
62 ) -> Result<Self> {
63 let mut child = util::command::new_smol_command(command.path)
64 .args(command.args.iter().map(|arg| arg.as_str()))
65 .envs(command.env.iter().flatten())
66 .current_dir(root_dir)
67 .stdin(std::process::Stdio::piped())
68 .stdout(std::process::Stdio::piped())
69 .stderr(std::process::Stdio::piped())
70 .kill_on_drop(true)
71 .spawn()?;
72
73 let stdout = child.stdout.take().context("Failed to take stdout")?;
74 let stdin = child.stdin.take().context("Failed to take stdin")?;
75 let stderr = child.stderr.take().context("Failed to take stderr")?;
76 log::trace!("Spawned (pid: {})", child.id());
77
78 let sessions = Rc::new(RefCell::new(HashMap::default()));
79
80 let client = ClientDelegate {
81 sessions: sessions.clone(),
82 cx: cx.clone(),
83 };
84 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
85 let foreground_executor = cx.foreground_executor().clone();
86 move |fut| {
87 foreground_executor.spawn(fut).detach();
88 }
89 });
90
91 let io_task = cx.background_spawn(io_task);
92
93 let stderr_task = cx.background_spawn(async move {
94 let mut stderr = BufReader::new(stderr);
95 let mut line = String::new();
96 while let Ok(n) = stderr.read_line(&mut line).await
97 && n > 0
98 {
99 log::warn!("agent stderr: {}", &line);
100 line.clear();
101 }
102 Ok(())
103 });
104
105 let wait_task = cx.spawn({
106 let sessions = sessions.clone();
107 async move |cx| {
108 let status = child.status().await?;
109
110 for session in sessions.borrow().values() {
111 session
112 .thread
113 .update(cx, |thread, cx| {
114 thread.emit_load_error(LoadError::Exited { status }, cx)
115 })
116 .ok();
117 }
118
119 anyhow::Ok(())
120 }
121 });
122
123 let connection = Rc::new(connection);
124
125 cx.update(|cx| {
126 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
127 registry.set_active_connection(server_name.clone(), &connection, cx)
128 });
129 })?;
130
131 let response = connection
132 .initialize(acp::InitializeRequest {
133 protocol_version: acp::VERSION,
134 client_capabilities: acp::ClientCapabilities {
135 fs: acp::FileSystemCapability {
136 read_text_file: true,
137 write_text_file: true,
138 },
139 },
140 })
141 .await?;
142
143 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
144 return Err(UnsupportedVersion.into());
145 }
146
147 Ok(Self {
148 auth_methods: response.auth_methods,
149 connection,
150 server_name,
151 sessions,
152 prompt_capabilities: response.agent_capabilities.prompt_capabilities,
153 _io_task: io_task,
154 _wait_task: wait_task,
155 _stderr_task: stderr_task,
156 })
157 }
158
159 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
160 &self.prompt_capabilities
161 }
162}
163
164impl AgentConnection for AcpConnection {
165 fn new_thread(
166 self: Rc<Self>,
167 project: Entity<Project>,
168 cwd: &Path,
169 cx: &mut App,
170 ) -> Task<Result<Entity<AcpThread>>> {
171 let conn = self.connection.clone();
172 let sessions = self.sessions.clone();
173 let cwd = cwd.to_path_buf();
174 let context_server_store = project.read(cx).context_server_store().read(cx);
175 let mcp_servers = context_server_store
176 .configured_server_ids()
177 .iter()
178 .filter_map(|id| {
179 let configuration = context_server_store.configuration_for_server(id)?;
180 let command = configuration.command();
181 Some(acp::McpServer {
182 name: id.0.to_string(),
183 command: command.path.clone(),
184 args: command.args.clone(),
185 env: if let Some(env) = command.env.as_ref() {
186 env.iter()
187 .map(|(name, value)| acp::EnvVariable {
188 name: name.clone(),
189 value: value.clone(),
190 })
191 .collect()
192 } else {
193 vec![]
194 },
195 })
196 })
197 .collect();
198
199 cx.spawn(async move |cx| {
200 let response = conn
201 .new_session(acp::NewSessionRequest { mcp_servers, cwd })
202 .await
203 .map_err(|err| {
204 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
205 let mut error = AuthRequired::new();
206
207 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
208 error = error.with_description(err.message);
209 }
210
211 anyhow!(error)
212 } else {
213 anyhow!(err)
214 }
215 })?;
216
217 let session_id = response.session_id;
218 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
219 let thread = cx.new(|cx| {
220 AcpThread::new(
221 self.server_name.clone(),
222 self.clone(),
223 project,
224 action_log,
225 session_id.clone(),
226 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
227 watch::Receiver::constant(self.prompt_capabilities),
228 cx,
229 )
230 })?;
231
232 let session = AcpSession {
233 thread: thread.downgrade(),
234 suppress_abort_err: false,
235 };
236 sessions.borrow_mut().insert(session_id, session);
237
238 Ok(thread)
239 })
240 }
241
242 fn auth_methods(&self) -> &[acp::AuthMethod] {
243 &self.auth_methods
244 }
245
246 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
247 let conn = self.connection.clone();
248 cx.foreground_executor().spawn(async move {
249 let result = conn
250 .authenticate(acp::AuthenticateRequest {
251 method_id: method_id.clone(),
252 })
253 .await?;
254
255 Ok(result)
256 })
257 }
258
259 fn prompt(
260 &self,
261 _id: Option<acp_thread::UserMessageId>,
262 params: acp::PromptRequest,
263 cx: &mut App,
264 ) -> Task<Result<acp::PromptResponse>> {
265 let conn = self.connection.clone();
266 let sessions = self.sessions.clone();
267 let session_id = params.session_id.clone();
268 cx.foreground_executor().spawn(async move {
269 let result = conn.prompt(params).await;
270
271 let mut suppress_abort_err = false;
272
273 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
274 suppress_abort_err = session.suppress_abort_err;
275 session.suppress_abort_err = false;
276 }
277
278 match result {
279 Ok(response) => Ok(response),
280 Err(err) => {
281 if err.code != ErrorCode::INTERNAL_ERROR.code {
282 anyhow::bail!(err)
283 }
284
285 let Some(data) = &err.data else {
286 anyhow::bail!(err)
287 };
288
289 // Temporary workaround until the following PR is generally available:
290 // https://github.com/google-gemini/gemini-cli/pull/6656
291
292 #[derive(Deserialize)]
293 #[serde(deny_unknown_fields)]
294 struct ErrorDetails {
295 details: Box<str>,
296 }
297
298 match serde_json::from_value(data.clone()) {
299 Ok(ErrorDetails { details }) => {
300 if suppress_abort_err
301 && (details.contains("This operation was aborted")
302 || details.contains("The user aborted a request"))
303 {
304 Ok(acp::PromptResponse {
305 stop_reason: acp::StopReason::Cancelled,
306 })
307 } else {
308 Err(anyhow!(details))
309 }
310 }
311 Err(_) => Err(anyhow!(err)),
312 }
313 }
314 }
315 })
316 }
317
318 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
319 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
320 session.suppress_abort_err = true;
321 }
322 let conn = self.connection.clone();
323 let params = acp::CancelNotification {
324 session_id: session_id.clone(),
325 };
326 cx.foreground_executor()
327 .spawn(async move { conn.cancel(params).await })
328 .detach();
329 }
330
331 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
332 self
333 }
334}
335
/// Implements the client side of ACP for an [`AcpConnection`]: handles requests and
/// notifications sent by the agent by routing each to the thread of the named session.
struct ClientDelegate {
    /// Session table shared with the owning [`AcpConnection`].
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app context used to read settings and update session threads.
    cx: AsyncApp,
}
340
341impl acp::Client for ClientDelegate {
342 async fn request_permission(
343 &self,
344 arguments: acp::RequestPermissionRequest,
345 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
346 let cx = &mut self.cx.clone();
347
348 // If always_allow_tool_actions is enabled, then auto-choose the first "Allow" button
349 if AgentSettings::try_read_global(cx, |settings| settings.always_allow_tool_actions)
350 .unwrap_or(false)
351 {
352 // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
353 // some tools would (incorrectly) continue to auto-accept.
354 if let Some(allow_once_option) = arguments.options.iter().find_map(|option| {
355 if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
356 Some(option.id.clone())
357 } else {
358 None
359 }
360 }) {
361 return Ok(acp::RequestPermissionResponse {
362 outcome: acp::RequestPermissionOutcome::Selected {
363 option_id: allow_once_option,
364 },
365 });
366 }
367 }
368
369 let rx = self
370 .sessions
371 .borrow()
372 .get(&arguments.session_id)
373 .context("Failed to get session")?
374 .thread
375 .update(cx, |thread, cx| {
376 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
377 })?;
378
379 let result = rx?.await;
380
381 let outcome = match result {
382 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
383 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
384 };
385
386 Ok(acp::RequestPermissionResponse { outcome })
387 }
388
389 async fn write_text_file(
390 &self,
391 arguments: acp::WriteTextFileRequest,
392 ) -> Result<(), acp::Error> {
393 let cx = &mut self.cx.clone();
394 let task = self
395 .sessions
396 .borrow()
397 .get(&arguments.session_id)
398 .context("Failed to get session")?
399 .thread
400 .update(cx, |thread, cx| {
401 thread.write_text_file(arguments.path, arguments.content, cx)
402 })?;
403
404 task.await?;
405
406 Ok(())
407 }
408
409 async fn read_text_file(
410 &self,
411 arguments: acp::ReadTextFileRequest,
412 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
413 let cx = &mut self.cx.clone();
414 let task = self
415 .sessions
416 .borrow()
417 .get(&arguments.session_id)
418 .context("Failed to get session")?
419 .thread
420 .update(cx, |thread, cx| {
421 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
422 })?;
423
424 let content = task.await?;
425
426 Ok(acp::ReadTextFileResponse { content })
427 }
428
429 async fn session_notification(
430 &self,
431 notification: acp::SessionNotification,
432 ) -> Result<(), acp::Error> {
433 let cx = &mut self.cx.clone();
434 let sessions = self.sessions.borrow();
435 let session = sessions
436 .get(¬ification.session_id)
437 .context("Failed to get session")?;
438
439 session.thread.update(cx, |thread, cx| {
440 thread.handle_session_update(notification.update, cx)
441 })??;
442
443 Ok(())
444 }
445}