1use action_log::ActionLog;
2use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
3use anyhow::anyhow;
4use collections::HashMap;
5use futures::AsyncBufReadExt as _;
6use futures::channel::oneshot;
7use futures::io::BufReader;
8use project::Project;
9use serde::Deserialize;
10use std::path::Path;
11use std::rc::Rc;
12use std::{any::Any, cell::RefCell};
13
14use anyhow::{Context as _, Result};
15use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
16
17use crate::{AgentServerCommand, acp::UnsupportedVersion};
18use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError};
19
/// A connection to an agent server that speaks the Agent Client Protocol (ACP)
/// over the stdio pipes of a spawned child process (see [`AcpConnection::stdio`]).
pub struct AcpConnection {
    /// Name of the agent server; handed to every `AcpThread` this connection creates.
    server_name: &'static str,
    /// Client side of the ACP connection, used to send requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions created via `new_thread`, keyed by the agent-assigned session id.
    /// Shared with the `ClientDelegate` that serves agent-initiated requests.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Prompt capabilities advertised by the agent in its `initialize` response.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task driving the connection's I/O. It is stored so it is not dropped while the
    /// connection is alive. NOTE(review): dropping a gpui `Task` presumably cancels
    /// it, so field order / drop order matters here — confirm before reordering.
    _io_task: Task<Result<()>>,
}
28
/// Per-session bookkeeping kept by an [`AcpConnection`].
pub struct AcpSession {
    /// Weak handle to the thread that renders this session.
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel` and cleared when the next `prompt` completes. While set, an
    /// internal error whose details contain "This operation was aborted" is reported
    /// as a cancellation instead of a failure.
    suppress_abort_err: bool,
}
33
/// Oldest ACP protocol version accepted from an agent. `AcpConnection::stdio` fails
/// with `UnsupportedVersion` when the agent's `initialize` response reports an older one.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
35
36impl AcpConnection {
37 pub async fn stdio(
38 server_name: &'static str,
39 command: AgentServerCommand,
40 root_dir: &Path,
41 cx: &mut AsyncApp,
42 ) -> Result<Self> {
43 let mut child = util::command::new_smol_command(&command.path)
44 .args(command.args.iter().map(|arg| arg.as_str()))
45 .envs(command.env.iter().flatten())
46 .current_dir(root_dir)
47 .stdin(std::process::Stdio::piped())
48 .stdout(std::process::Stdio::piped())
49 .stderr(std::process::Stdio::piped())
50 .kill_on_drop(true)
51 .spawn()?;
52
53 let stdout = child.stdout.take().context("Failed to take stdout")?;
54 let stdin = child.stdin.take().context("Failed to take stdin")?;
55 let stderr = child.stderr.take().context("Failed to take stderr")?;
56 log::trace!("Spawned (pid: {})", child.id());
57
58 let sessions = Rc::new(RefCell::new(HashMap::default()));
59
60 let client = ClientDelegate {
61 sessions: sessions.clone(),
62 cx: cx.clone(),
63 };
64 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
65 let foreground_executor = cx.foreground_executor().clone();
66 move |fut| {
67 foreground_executor.spawn(fut).detach();
68 }
69 });
70
71 let io_task = cx.background_spawn(io_task);
72
73 cx.background_spawn(async move {
74 let mut stderr = BufReader::new(stderr);
75 let mut line = String::new();
76 while let Ok(n) = stderr.read_line(&mut line).await
77 && n > 0
78 {
79 log::warn!("agent stderr: {}", &line);
80 line.clear();
81 }
82 })
83 .detach();
84
85 cx.spawn({
86 let sessions = sessions.clone();
87 async move |cx| {
88 let status = child.status().await?;
89
90 for session in sessions.borrow().values() {
91 session
92 .thread
93 .update(cx, |thread, cx| {
94 thread.emit_load_error(LoadError::Exited { status }, cx)
95 })
96 .ok();
97 }
98
99 anyhow::Ok(())
100 }
101 })
102 .detach();
103
104 let response = connection
105 .initialize(acp::InitializeRequest {
106 protocol_version: acp::VERSION,
107 client_capabilities: acp::ClientCapabilities {
108 fs: acp::FileSystemCapability {
109 read_text_file: true,
110 write_text_file: true,
111 },
112 },
113 })
114 .await?;
115
116 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
117 return Err(UnsupportedVersion.into());
118 }
119
120 Ok(Self {
121 auth_methods: response.auth_methods,
122 connection: connection.into(),
123 server_name,
124 sessions,
125 prompt_capabilities: response.agent_capabilities.prompt_capabilities,
126 _io_task: io_task,
127 })
128 }
129}
130
131impl AgentConnection for AcpConnection {
132 fn new_thread(
133 self: Rc<Self>,
134 project: Entity<Project>,
135 cwd: &Path,
136 cx: &mut App,
137 ) -> Task<Result<Entity<AcpThread>>> {
138 let conn = self.connection.clone();
139 let sessions = self.sessions.clone();
140 let cwd = cwd.to_path_buf();
141 cx.spawn(async move |cx| {
142 let response = conn
143 .new_session(acp::NewSessionRequest {
144 mcp_servers: vec![],
145 cwd,
146 })
147 .await
148 .map_err(|err| {
149 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
150 let mut error = AuthRequired::new();
151
152 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
153 error = error.with_description(err.message);
154 }
155
156 anyhow!(error)
157 } else {
158 anyhow!(err)
159 }
160 })?;
161
162 let session_id = response.session_id;
163 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
164 let thread = cx.new(|_cx| {
165 AcpThread::new(
166 self.server_name,
167 self.clone(),
168 project,
169 action_log,
170 session_id.clone(),
171 )
172 })?;
173
174 let session = AcpSession {
175 thread: thread.downgrade(),
176 suppress_abort_err: false,
177 };
178 sessions.borrow_mut().insert(session_id, session);
179
180 Ok(thread)
181 })
182 }
183
184 fn auth_methods(&self) -> &[acp::AuthMethod] {
185 &self.auth_methods
186 }
187
188 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
189 let conn = self.connection.clone();
190 cx.foreground_executor().spawn(async move {
191 let result = conn
192 .authenticate(acp::AuthenticateRequest {
193 method_id: method_id.clone(),
194 })
195 .await?;
196
197 Ok(result)
198 })
199 }
200
201 fn prompt(
202 &self,
203 _id: Option<acp_thread::UserMessageId>,
204 params: acp::PromptRequest,
205 cx: &mut App,
206 ) -> Task<Result<acp::PromptResponse>> {
207 let conn = self.connection.clone();
208 let sessions = self.sessions.clone();
209 let session_id = params.session_id.clone();
210 cx.foreground_executor().spawn(async move {
211 let result = conn.prompt(params).await;
212
213 let mut suppress_abort_err = false;
214
215 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
216 suppress_abort_err = session.suppress_abort_err;
217 session.suppress_abort_err = false;
218 }
219
220 match result {
221 Ok(response) => Ok(response),
222 Err(err) => {
223 if err.code != ErrorCode::INTERNAL_ERROR.code {
224 anyhow::bail!(err)
225 }
226
227 let Some(data) = &err.data else {
228 anyhow::bail!(err)
229 };
230
231 // Temporary workaround until the following PR is generally available:
232 // https://github.com/google-gemini/gemini-cli/pull/6656
233
234 #[derive(Deserialize)]
235 #[serde(deny_unknown_fields)]
236 struct ErrorDetails {
237 details: Box<str>,
238 }
239
240 match serde_json::from_value(data.clone()) {
241 Ok(ErrorDetails { details }) => {
242 if suppress_abort_err && details.contains("This operation was aborted")
243 {
244 Ok(acp::PromptResponse {
245 stop_reason: acp::StopReason::Canceled,
246 })
247 } else {
248 Err(anyhow!(details))
249 }
250 }
251 Err(_) => Err(anyhow!(err)),
252 }
253 }
254 }
255 })
256 }
257
258 fn prompt_capabilities(&self) -> acp::PromptCapabilities {
259 self.prompt_capabilities
260 }
261
262 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
263 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
264 session.suppress_abort_err = true;
265 }
266 let conn = self.connection.clone();
267 let params = acp::CancelNotification {
268 session_id: session_id.clone(),
269 };
270 cx.foreground_executor()
271 .spawn(async move { conn.cancel(params).await })
272 .detach();
273 }
274
275 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
276 self
277 }
278}
279
/// Serves requests and notifications the agent sends to this client. Each one is
/// routed to the `AcpThread` of the session it names.
struct ClientDelegate {
    /// Session table shared with the owning `AcpConnection`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app context used to update session threads.
    cx: AsyncApp,
}
284
285impl acp::Client for ClientDelegate {
286 async fn request_permission(
287 &self,
288 arguments: acp::RequestPermissionRequest,
289 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
290 let cx = &mut self.cx.clone();
291 let rx = self
292 .sessions
293 .borrow()
294 .get(&arguments.session_id)
295 .context("Failed to get session")?
296 .thread
297 .update(cx, |thread, cx| {
298 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
299 })?;
300
301 let result = rx?.await;
302
303 let outcome = match result {
304 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
305 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
306 };
307
308 Ok(acp::RequestPermissionResponse { outcome })
309 }
310
311 async fn write_text_file(
312 &self,
313 arguments: acp::WriteTextFileRequest,
314 ) -> Result<(), acp::Error> {
315 let cx = &mut self.cx.clone();
316 let task = self
317 .sessions
318 .borrow()
319 .get(&arguments.session_id)
320 .context("Failed to get session")?
321 .thread
322 .update(cx, |thread, cx| {
323 thread.write_text_file(arguments.path, arguments.content, cx)
324 })?;
325
326 task.await?;
327
328 Ok(())
329 }
330
331 async fn read_text_file(
332 &self,
333 arguments: acp::ReadTextFileRequest,
334 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
335 let cx = &mut self.cx.clone();
336 let task = self
337 .sessions
338 .borrow()
339 .get(&arguments.session_id)
340 .context("Failed to get session")?
341 .thread
342 .update(cx, |thread, cx| {
343 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
344 })?;
345
346 let content = task.await?;
347
348 Ok(acp::ReadTextFileResponse { content })
349 }
350
351 async fn session_notification(
352 &self,
353 notification: acp::SessionNotification,
354 ) -> Result<(), acp::Error> {
355 let cx = &mut self.cx.clone();
356 let sessions = self.sessions.borrow();
357 let session = sessions
358 .get(¬ification.session_id)
359 .context("Failed to get session")?;
360
361 session.thread.update(cx, |thread, cx| {
362 thread.handle_session_update(notification.update, cx)
363 })??;
364
365 Ok(())
366 }
367}