1use acp_tools::AcpConnectionRegistry;
2use action_log::ActionLog;
3use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
4use anyhow::anyhow;
5use collections::HashMap;
6use futures::AsyncBufReadExt as _;
7use futures::channel::oneshot;
8use futures::io::BufReader;
9use project::Project;
10use serde::Deserialize;
11use std::path::Path;
12use std::rc::Rc;
13use std::{any::Any, cell::RefCell};
14
15use anyhow::{Context as _, Result};
16use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
17
18use crate::{AgentServerCommand, acp::UnsupportedVersion};
19use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError};
20
21pub struct AcpConnection {
22 server_name: &'static str,
23 connection: Rc<acp::ClientSideConnection>,
24 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
25 auth_methods: Vec<acp::AuthMethod>,
26 prompt_capabilities: acp::PromptCapabilities,
27 _io_task: Task<Result<()>>,
28}
29
30pub struct AcpSession {
31 thread: WeakEntity<AcpThread>,
32 suppress_abort_err: bool,
33}
34
35const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
36
37impl AcpConnection {
38 pub async fn stdio(
39 server_name: &'static str,
40 command: AgentServerCommand,
41 root_dir: &Path,
42 cx: &mut AsyncApp,
43 ) -> Result<Self> {
44 let mut child = util::command::new_smol_command(&command.path)
45 .args(command.args.iter().map(|arg| arg.as_str()))
46 .envs(command.env.iter().flatten())
47 .current_dir(root_dir)
48 .stdin(std::process::Stdio::piped())
49 .stdout(std::process::Stdio::piped())
50 .stderr(std::process::Stdio::piped())
51 .kill_on_drop(true)
52 .spawn()?;
53
54 let stdout = child.stdout.take().context("Failed to take stdout")?;
55 let stdin = child.stdin.take().context("Failed to take stdin")?;
56 let stderr = child.stderr.take().context("Failed to take stderr")?;
57 log::trace!("Spawned (pid: {})", child.id());
58
59 let sessions = Rc::new(RefCell::new(HashMap::default()));
60
61 let client = ClientDelegate {
62 sessions: sessions.clone(),
63 cx: cx.clone(),
64 };
65 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
66 let foreground_executor = cx.foreground_executor().clone();
67 move |fut| {
68 foreground_executor.spawn(fut).detach();
69 }
70 });
71
72 let io_task = cx.background_spawn(io_task);
73
74 cx.background_spawn(async move {
75 let mut stderr = BufReader::new(stderr);
76 let mut line = String::new();
77 while let Ok(n) = stderr.read_line(&mut line).await
78 && n > 0
79 {
80 log::warn!("agent stderr: {}", &line);
81 line.clear();
82 }
83 })
84 .detach();
85
86 cx.spawn({
87 let sessions = sessions.clone();
88 async move |cx| {
89 let status = child.status().await?;
90
91 for session in sessions.borrow().values() {
92 session
93 .thread
94 .update(cx, |thread, cx| {
95 thread.emit_load_error(LoadError::Exited { status }, cx)
96 })
97 .ok();
98 }
99
100 anyhow::Ok(())
101 }
102 })
103 .detach();
104
105 let connection = Rc::new(connection);
106
107 cx.update(|cx| {
108 AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
109 registry.set_active_connection(server_name, &connection, cx)
110 });
111 })?;
112
113 let response = connection
114 .initialize(acp::InitializeRequest {
115 protocol_version: acp::VERSION,
116 client_capabilities: acp::ClientCapabilities {
117 fs: acp::FileSystemCapability {
118 read_text_file: true,
119 write_text_file: true,
120 },
121 },
122 })
123 .await?;
124
125 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
126 return Err(UnsupportedVersion.into());
127 }
128
129 Ok(Self {
130 auth_methods: response.auth_methods,
131 connection,
132 server_name,
133 sessions,
134 prompt_capabilities: response.agent_capabilities.prompt_capabilities,
135 _io_task: io_task,
136 })
137 }
138}
139
140impl AgentConnection for AcpConnection {
141 fn new_thread(
142 self: Rc<Self>,
143 project: Entity<Project>,
144 cwd: &Path,
145 cx: &mut App,
146 ) -> Task<Result<Entity<AcpThread>>> {
147 let conn = self.connection.clone();
148 let sessions = self.sessions.clone();
149 let cwd = cwd.to_path_buf();
150 cx.spawn(async move |cx| {
151 let response = conn
152 .new_session(acp::NewSessionRequest {
153 mcp_servers: vec![],
154 cwd,
155 })
156 .await
157 .map_err(|err| {
158 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
159 let mut error = AuthRequired::new();
160
161 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
162 error = error.with_description(err.message);
163 }
164
165 anyhow!(error)
166 } else {
167 anyhow!(err)
168 }
169 })?;
170
171 let session_id = response.session_id;
172 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
173 let thread = cx.new(|_cx| {
174 AcpThread::new(
175 self.server_name,
176 self.clone(),
177 project,
178 action_log,
179 session_id.clone(),
180 )
181 })?;
182
183 let session = AcpSession {
184 thread: thread.downgrade(),
185 suppress_abort_err: false,
186 };
187 sessions.borrow_mut().insert(session_id, session);
188
189 Ok(thread)
190 })
191 }
192
193 fn auth_methods(&self) -> &[acp::AuthMethod] {
194 &self.auth_methods
195 }
196
197 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
198 let conn = self.connection.clone();
199 cx.foreground_executor().spawn(async move {
200 let result = conn
201 .authenticate(acp::AuthenticateRequest {
202 method_id: method_id.clone(),
203 })
204 .await?;
205
206 Ok(result)
207 })
208 }
209
210 fn prompt(
211 &self,
212 _id: Option<acp_thread::UserMessageId>,
213 params: acp::PromptRequest,
214 cx: &mut App,
215 ) -> Task<Result<acp::PromptResponse>> {
216 let conn = self.connection.clone();
217 let sessions = self.sessions.clone();
218 let session_id = params.session_id.clone();
219 cx.foreground_executor().spawn(async move {
220 let result = conn.prompt(params).await;
221
222 let mut suppress_abort_err = false;
223
224 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
225 suppress_abort_err = session.suppress_abort_err;
226 session.suppress_abort_err = false;
227 }
228
229 match result {
230 Ok(response) => Ok(response),
231 Err(err) => {
232 if err.code != ErrorCode::INTERNAL_ERROR.code {
233 anyhow::bail!(err)
234 }
235
236 let Some(data) = &err.data else {
237 anyhow::bail!(err)
238 };
239
240 // Temporary workaround until the following PR is generally available:
241 // https://github.com/google-gemini/gemini-cli/pull/6656
242
243 #[derive(Deserialize)]
244 #[serde(deny_unknown_fields)]
245 struct ErrorDetails {
246 details: Box<str>,
247 }
248
249 match serde_json::from_value(data.clone()) {
250 Ok(ErrorDetails { details }) => {
251 if suppress_abort_err && details.contains("This operation was aborted")
252 {
253 Ok(acp::PromptResponse {
254 stop_reason: acp::StopReason::Cancelled,
255 })
256 } else {
257 Err(anyhow!(details))
258 }
259 }
260 Err(_) => Err(anyhow!(err)),
261 }
262 }
263 }
264 })
265 }
266
267 fn prompt_capabilities(&self) -> acp::PromptCapabilities {
268 self.prompt_capabilities
269 }
270
271 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
272 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
273 session.suppress_abort_err = true;
274 }
275 let conn = self.connection.clone();
276 let params = acp::CancelNotification {
277 session_id: session_id.clone(),
278 };
279 cx.foreground_executor()
280 .spawn(async move { conn.cancel(params).await })
281 .detach();
282 }
283
284 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
285 self
286 }
287}
288
/// Receives requests and notifications sent by the agent and routes each one
/// to the [`AcpThread`] of the session it names.
struct ClientDelegate {
    /// The same session map held by the owning [`AcpConnection`].
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App handle used to update thread entities from protocol callbacks.
    cx: AsyncApp,
}
293
294impl acp::Client for ClientDelegate {
295 async fn request_permission(
296 &self,
297 arguments: acp::RequestPermissionRequest,
298 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
299 let cx = &mut self.cx.clone();
300 let rx = self
301 .sessions
302 .borrow()
303 .get(&arguments.session_id)
304 .context("Failed to get session")?
305 .thread
306 .update(cx, |thread, cx| {
307 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
308 })?;
309
310 let result = rx?.await;
311
312 let outcome = match result {
313 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
314 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
315 };
316
317 Ok(acp::RequestPermissionResponse { outcome })
318 }
319
320 async fn write_text_file(
321 &self,
322 arguments: acp::WriteTextFileRequest,
323 ) -> Result<(), acp::Error> {
324 let cx = &mut self.cx.clone();
325 let task = self
326 .sessions
327 .borrow()
328 .get(&arguments.session_id)
329 .context("Failed to get session")?
330 .thread
331 .update(cx, |thread, cx| {
332 thread.write_text_file(arguments.path, arguments.content, cx)
333 })?;
334
335 task.await?;
336
337 Ok(())
338 }
339
340 async fn read_text_file(
341 &self,
342 arguments: acp::ReadTextFileRequest,
343 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
344 let cx = &mut self.cx.clone();
345 let task = self
346 .sessions
347 .borrow()
348 .get(&arguments.session_id)
349 .context("Failed to get session")?
350 .thread
351 .update(cx, |thread, cx| {
352 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
353 })?;
354
355 let content = task.await?;
356
357 Ok(acp::ReadTextFileResponse { content })
358 }
359
360 async fn session_notification(
361 &self,
362 notification: acp::SessionNotification,
363 ) -> Result<(), acp::Error> {
364 let cx = &mut self.cx.clone();
365 let sessions = self.sessions.borrow();
366 let session = sessions
367 .get(¬ification.session_id)
368 .context("Failed to get session")?;
369
370 session.thread.update(cx, |thread, cx| {
371 thread.handle_session_update(notification.update, cx)
372 })??;
373
374 Ok(())
375 }
376}