1use crate::AgentServerCommand;
2use acp_thread::AgentConnection;
3use acp_tools::AcpConnectionRegistry;
4use action_log::ActionLog;
5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
6use anyhow::anyhow;
7use collections::HashMap;
8use futures::AsyncBufReadExt as _;
9use futures::io::BufReader;
10use project::Project;
11use serde::Deserialize;
12
13use std::{any::Any, cell::RefCell};
14use std::{path::Path, rc::Rc};
15use thiserror::Error;
16
17use anyhow::{Context as _, Result};
18use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
19
20use acp_thread::{AcpThread, AuthRequired, LoadError};
21
22#[derive(Debug, Error)]
23#[error("Unsupported version")]
24pub struct UnsupportedVersion;
25
26pub struct AcpConnection {
27 server_name: SharedString,
28 connection: Rc<acp::ClientSideConnection>,
29 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
30 auth_methods: Vec<acp::AuthMethod>,
31 prompt_capabilities: acp::PromptCapabilities,
32 _io_task: Task<Result<()>>,
33 _wait_task: Task<Result<()>>,
34 _stderr_task: Task<Result<()>>,
35}
36
37pub struct AcpSession {
38 thread: WeakEntity<AcpThread>,
39 suppress_abort_err: bool,
40}
41
42pub async fn connect(
43 server_name: SharedString,
44 command: AgentServerCommand,
45 root_dir: &Path,
46 cx: &mut AsyncApp,
47) -> Result<Rc<dyn AgentConnection>> {
48 let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
49 Ok(Rc::new(conn) as _)
50}
51
/// Oldest ACP protocol version accepted from an agent. `AcpConnection::stdio`
/// rejects agents that report an older version with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
53
54impl AcpConnection {
55 pub async fn stdio(
56 server_name: SharedString,
57 command: AgentServerCommand,
58 root_dir: &Path,
59 cx: &mut AsyncApp,
60 ) -> Result<Self> {
61 let mut child = util::command::new_smol_command(command.path)
62 .args(command.args.iter().map(|arg| arg.as_str()))
63 .envs(command.env.iter().flatten())
64 .current_dir(root_dir)
65 .stdin(std::process::Stdio::piped())
66 .stdout(std::process::Stdio::piped())
67 .stderr(std::process::Stdio::piped())
68 .kill_on_drop(true)
69 .spawn()?;
70
71 let stdout = child.stdout.take().context("Failed to take stdout")?;
72 let stdin = child.stdin.take().context("Failed to take stdin")?;
73 let stderr = child.stderr.take().context("Failed to take stderr")?;
74 log::trace!("Spawned (pid: {})", child.id());
75
76 let sessions = Rc::new(RefCell::new(HashMap::default()));
77
78 let client = ClientDelegate {
79 sessions: sessions.clone(),
80 cx: cx.clone(),
81 };
82 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
83 let foreground_executor = cx.foreground_executor().clone();
84 move |fut| {
85 foreground_executor.spawn(fut).detach();
86 }
87 });
88
89 let io_task = cx.background_spawn(io_task);
90
91 let stderr_task = cx.background_spawn(async move {
92 let mut stderr = BufReader::new(stderr);
93 let mut line = String::new();
94 while let Ok(n) = stderr.read_line(&mut line).await
95 && n > 0
96 {
97 log::warn!("agent stderr: {}", &line);
98 line.clear();
99 }
100 Ok(())
101 });
102
103 let wait_task = cx.spawn({
104 let sessions = sessions.clone();
105 async move |cx| {
106 let status = child.status().await?;
107
108 for session in sessions.borrow().values() {
109 session
110 .thread
111 .update(cx, |thread, cx| {
112 thread.emit_load_error(LoadError::Exited { status }, cx)
113 })
114 .ok();
115 }
116
117 anyhow::Ok(())
118 }
119 });
120
121 let connection = Rc::new(connection);
122
123 cx.update(|cx| {
124 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
125 registry.set_active_connection(server_name.clone(), &connection, cx)
126 });
127 })?;
128
129 let response = connection
130 .initialize(acp::InitializeRequest {
131 protocol_version: acp::VERSION,
132 client_capabilities: acp::ClientCapabilities {
133 fs: acp::FileSystemCapability {
134 read_text_file: true,
135 write_text_file: true,
136 },
137 },
138 })
139 .await?;
140
141 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
142 return Err(UnsupportedVersion.into());
143 }
144
145 Ok(Self {
146 auth_methods: response.auth_methods,
147 connection,
148 server_name,
149 sessions,
150 prompt_capabilities: response.agent_capabilities.prompt_capabilities,
151 _io_task: io_task,
152 _wait_task: wait_task,
153 _stderr_task: stderr_task,
154 })
155 }
156
157 pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
158 &self.prompt_capabilities
159 }
160}
161
162impl AgentConnection for AcpConnection {
163 fn new_thread(
164 self: Rc<Self>,
165 project: Entity<Project>,
166 cwd: &Path,
167 cx: &mut App,
168 ) -> Task<Result<Entity<AcpThread>>> {
169 let conn = self.connection.clone();
170 let sessions = self.sessions.clone();
171 let cwd = cwd.to_path_buf();
172 let context_server_store = project.read(cx).context_server_store().read(cx);
173 let mcp_servers = context_server_store
174 .configured_server_ids()
175 .iter()
176 .filter_map(|id| {
177 let configuration = context_server_store.configuration_for_server(id)?;
178 let command = configuration.command();
179 Some(acp::McpServer {
180 name: id.0.to_string(),
181 command: command.path.clone(),
182 args: command.args.clone(),
183 env: if let Some(env) = command.env.as_ref() {
184 env.iter()
185 .map(|(name, value)| acp::EnvVariable {
186 name: name.clone(),
187 value: value.clone(),
188 })
189 .collect()
190 } else {
191 vec![]
192 },
193 })
194 })
195 .collect();
196
197 cx.spawn(async move |cx| {
198 let response = conn
199 .new_session(acp::NewSessionRequest { mcp_servers, cwd })
200 .await
201 .map_err(|err| {
202 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
203 let mut error = AuthRequired::new();
204
205 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
206 error = error.with_description(err.message);
207 }
208
209 anyhow!(error)
210 } else {
211 anyhow!(err)
212 }
213 })?;
214
215 let session_id = response.session_id;
216 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
217 let thread = cx.new(|cx| {
218 AcpThread::new(
219 self.server_name.clone(),
220 self.clone(),
221 project,
222 action_log,
223 session_id.clone(),
224 // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
225 watch::Receiver::constant(self.prompt_capabilities),
226 cx,
227 )
228 })?;
229
230 let session = AcpSession {
231 thread: thread.downgrade(),
232 suppress_abort_err: false,
233 };
234 sessions.borrow_mut().insert(session_id, session);
235
236 Ok(thread)
237 })
238 }
239
240 fn auth_methods(&self) -> &[acp::AuthMethod] {
241 &self.auth_methods
242 }
243
244 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
245 let conn = self.connection.clone();
246 cx.foreground_executor().spawn(async move {
247 let result = conn
248 .authenticate(acp::AuthenticateRequest {
249 method_id: method_id.clone(),
250 })
251 .await?;
252
253 Ok(result)
254 })
255 }
256
257 fn prompt(
258 &self,
259 _id: Option<acp_thread::UserMessageId>,
260 params: acp::PromptRequest,
261 cx: &mut App,
262 ) -> Task<Result<acp::PromptResponse>> {
263 let conn = self.connection.clone();
264 let sessions = self.sessions.clone();
265 let session_id = params.session_id.clone();
266 cx.foreground_executor().spawn(async move {
267 let result = conn.prompt(params).await;
268
269 let mut suppress_abort_err = false;
270
271 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
272 suppress_abort_err = session.suppress_abort_err;
273 session.suppress_abort_err = false;
274 }
275
276 match result {
277 Ok(response) => Ok(response),
278 Err(err) => {
279 if err.code != ErrorCode::INTERNAL_ERROR.code {
280 anyhow::bail!(err)
281 }
282
283 let Some(data) = &err.data else {
284 anyhow::bail!(err)
285 };
286
287 // Temporary workaround until the following PR is generally available:
288 // https://github.com/google-gemini/gemini-cli/pull/6656
289
290 #[derive(Deserialize)]
291 #[serde(deny_unknown_fields)]
292 struct ErrorDetails {
293 details: Box<str>,
294 }
295
296 match serde_json::from_value(data.clone()) {
297 Ok(ErrorDetails { details }) => {
298 if suppress_abort_err
299 && (details.contains("This operation was aborted")
300 || details.contains("The user aborted a request"))
301 {
302 Ok(acp::PromptResponse {
303 stop_reason: acp::StopReason::Cancelled,
304 })
305 } else {
306 Err(anyhow!(details))
307 }
308 }
309 Err(_) => Err(anyhow!(err)),
310 }
311 }
312 }
313 })
314 }
315
316 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
317 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
318 session.suppress_abort_err = true;
319 }
320 let conn = self.connection.clone();
321 let params = acp::CancelNotification {
322 session_id: session_id.clone(),
323 };
324 cx.foreground_executor()
325 .spawn(async move { conn.cancel(params).await })
326 .detach();
327 }
328
329 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
330 self
331 }
332}
333
334struct ClientDelegate {
335 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
336 cx: AsyncApp,
337}
338
339impl acp::Client for ClientDelegate {
340 async fn request_permission(
341 &self,
342 arguments: acp::RequestPermissionRequest,
343 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
344 let cx = &mut self.cx.clone();
345
346 let task = self
347 .sessions
348 .borrow()
349 .get(&arguments.session_id)
350 .context("Failed to get session")?
351 .thread
352 .update(cx, |thread, cx| {
353 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
354 })??;
355
356 let outcome = task.await;
357
358 Ok(acp::RequestPermissionResponse { outcome })
359 }
360
361 async fn write_text_file(
362 &self,
363 arguments: acp::WriteTextFileRequest,
364 ) -> Result<(), acp::Error> {
365 let cx = &mut self.cx.clone();
366 let task = self
367 .sessions
368 .borrow()
369 .get(&arguments.session_id)
370 .context("Failed to get session")?
371 .thread
372 .update(cx, |thread, cx| {
373 thread.write_text_file(arguments.path, arguments.content, cx)
374 })?;
375
376 task.await?;
377
378 Ok(())
379 }
380
381 async fn read_text_file(
382 &self,
383 arguments: acp::ReadTextFileRequest,
384 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
385 let cx = &mut self.cx.clone();
386 let task = self
387 .sessions
388 .borrow()
389 .get(&arguments.session_id)
390 .context("Failed to get session")?
391 .thread
392 .update(cx, |thread, cx| {
393 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
394 })?;
395
396 let content = task.await?;
397
398 Ok(acp::ReadTextFileResponse { content })
399 }
400
401 async fn session_notification(
402 &self,
403 notification: acp::SessionNotification,
404 ) -> Result<(), acp::Error> {
405 let cx = &mut self.cx.clone();
406 let sessions = self.sessions.borrow();
407 let session = sessions
408 .get(¬ification.session_id)
409 .context("Failed to get session")?;
410
411 session.thread.update(cx, |thread, cx| {
412 thread.handle_session_update(notification.update, cx)
413 })??;
414
415 Ok(())
416 }
417}