1use action_log::ActionLog;
2use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
3use anyhow::anyhow;
4use collections::HashMap;
5use futures::AsyncBufReadExt as _;
6use futures::channel::oneshot;
7use futures::io::BufReader;
8use project::Project;
9use serde::Deserialize;
10use std::path::Path;
11use std::rc::Rc;
12use std::{any::Any, cell::RefCell};
13
14use anyhow::{Context as _, Result};
15use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
16
17use crate::{AgentServerCommand, acp::UnsupportedVersion};
18use acp_thread::{AcpThread, AgentConnection, AuthRequired, LoadError};
19
/// A live connection to an agent server speaking the Agent Client Protocol
/// (ACP) over a child process's stdio.
pub struct AcpConnection {
    /// Name of the agent server; handed to every [`AcpThread`] created by
    /// this connection.
    server_name: &'static str,
    /// Client side of the ACP connection, used to send requests and
    /// notifications to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Per-session state keyed by ACP session id. Shared with the
    /// `ClientDelegate` so agent-initiated calls can be routed to threads.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize`
    /// response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Prompt capabilities advertised by the agent in its `initialize`
    /// response.
    prompt_capabilities: acp::PromptCapabilities,
    /// Spawned task driving the connection's I/O. It is stored but never read.
    /// NOTE(review): presumably kept so that dropping a `Task` does not stop
    /// the I/O loop while the connection is alive — confirm against gpui.
    _io_task: Task<Result<()>>,
}
28
/// Client-side bookkeeping for one ACP session.
pub struct AcpSession {
    /// The thread presenting this session. Held weakly so the session map
    /// does not keep the thread alive.
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel`. While set, `prompt` reports the agent's
    /// "This operation was aborted" internal error as
    /// `StopReason::Canceled` instead of as a failure.
    pending_cancel: bool,
}
33
/// Oldest ACP protocol version this client accepts. If an agent's
/// `initialize` response reports an older version, `AcpConnection::stdio`
/// fails with `UnsupportedVersion`.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
35
36impl AcpConnection {
37 pub async fn stdio(
38 server_name: &'static str,
39 command: AgentServerCommand,
40 root_dir: &Path,
41 cx: &mut AsyncApp,
42 ) -> Result<Self> {
43 let mut child = util::command::new_smol_command(&command.path)
44 .args(command.args.iter().map(|arg| arg.as_str()))
45 .envs(command.env.iter().flatten())
46 .current_dir(root_dir)
47 .stdin(std::process::Stdio::piped())
48 .stdout(std::process::Stdio::piped())
49 .stderr(std::process::Stdio::piped())
50 .kill_on_drop(true)
51 .spawn()?;
52
53 let stdout = child.stdout.take().context("Failed to take stdout")?;
54 let stdin = child.stdin.take().context("Failed to take stdin")?;
55 let stderr = child.stderr.take().context("Failed to take stderr")?;
56 log::trace!("Spawned (pid: {})", child.id());
57
58 let sessions = Rc::new(RefCell::new(HashMap::default()));
59
60 let client = ClientDelegate {
61 sessions: sessions.clone(),
62 cx: cx.clone(),
63 };
64 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
65 let foreground_executor = cx.foreground_executor().clone();
66 move |fut| {
67 foreground_executor.spawn(fut).detach();
68 }
69 });
70
71 let io_task = cx.background_spawn(io_task);
72
73 cx.background_spawn(async move {
74 let mut stderr = BufReader::new(stderr);
75 let mut line = String::new();
76 while let Ok(n) = stderr.read_line(&mut line).await
77 && n > 0
78 {
79 log::warn!("agent stderr: {}", &line);
80 line.clear();
81 }
82 })
83 .detach();
84
85 cx.spawn({
86 let sessions = sessions.clone();
87 async move |cx| {
88 let status = child.status().await?;
89
90 for session in sessions.borrow().values() {
91 session
92 .thread
93 .update(cx, |thread, cx| {
94 thread.emit_load_error(LoadError::Exited { status }, cx)
95 })
96 .ok();
97 }
98
99 anyhow::Ok(())
100 }
101 })
102 .detach();
103
104 let response = connection
105 .initialize(acp::InitializeRequest {
106 protocol_version: acp::VERSION,
107 client_capabilities: acp::ClientCapabilities {
108 fs: acp::FileSystemCapability {
109 read_text_file: true,
110 write_text_file: true,
111 },
112 },
113 })
114 .await?;
115
116 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
117 return Err(UnsupportedVersion.into());
118 }
119
120 Ok(Self {
121 auth_methods: response.auth_methods,
122 connection: connection.into(),
123 server_name,
124 sessions,
125 prompt_capabilities: response.agent_capabilities.prompt_capabilities,
126 _io_task: io_task,
127 })
128 }
129}
130
131impl AgentConnection for AcpConnection {
132 fn new_thread(
133 self: Rc<Self>,
134 project: Entity<Project>,
135 cwd: &Path,
136 cx: &mut App,
137 ) -> Task<Result<Entity<AcpThread>>> {
138 let conn = self.connection.clone();
139 let sessions = self.sessions.clone();
140 let cwd = cwd.to_path_buf();
141 cx.spawn(async move |cx| {
142 let response = conn
143 .new_session(acp::NewSessionRequest {
144 mcp_servers: vec![],
145 cwd,
146 })
147 .await
148 .map_err(|err| {
149 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
150 let mut error = AuthRequired::new();
151
152 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
153 error = error.with_description(err.message);
154 }
155
156 anyhow!(error)
157 } else {
158 anyhow!(err)
159 }
160 })?;
161
162 let session_id = response.session_id;
163 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
164 let thread = cx.new(|_cx| {
165 AcpThread::new(
166 self.server_name,
167 self.clone(),
168 project,
169 action_log,
170 session_id.clone(),
171 )
172 })?;
173
174 let session = AcpSession {
175 thread: thread.downgrade(),
176 pending_cancel: false,
177 };
178 sessions.borrow_mut().insert(session_id, session);
179
180 Ok(thread)
181 })
182 }
183
184 fn auth_methods(&self) -> &[acp::AuthMethod] {
185 &self.auth_methods
186 }
187
188 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
189 let conn = self.connection.clone();
190 cx.foreground_executor().spawn(async move {
191 let result = conn
192 .authenticate(acp::AuthenticateRequest {
193 method_id: method_id.clone(),
194 })
195 .await?;
196
197 Ok(result)
198 })
199 }
200
201 fn prompt(
202 &self,
203 _id: Option<acp_thread::UserMessageId>,
204 params: acp::PromptRequest,
205 cx: &mut App,
206 ) -> Task<Result<acp::PromptResponse>> {
207 let conn = self.connection.clone();
208 let sessions = self.sessions.clone();
209 let session_id = params.session_id.clone();
210 cx.foreground_executor().spawn(async move {
211 match conn.prompt(params).await {
212 Ok(response) => Ok(response),
213 Err(err) => {
214 if err.code != ErrorCode::INTERNAL_ERROR.code {
215 anyhow::bail!(err)
216 }
217
218 let Some(data) = &err.data else {
219 anyhow::bail!(err)
220 };
221
222 // Temporary workaround until the following PR is generally available:
223 // https://github.com/google-gemini/gemini-cli/pull/6656
224
225 #[derive(Deserialize)]
226 #[serde(deny_unknown_fields)]
227 struct ErrorDetails {
228 details: Box<str>,
229 }
230
231 match serde_json::from_value(data.clone()) {
232 Ok(ErrorDetails { details }) => {
233 if sessions
234 .borrow()
235 .get(&session_id)
236 .is_some_and(|session| session.pending_cancel)
237 && details.contains("This operation was aborted")
238 {
239 Ok(acp::PromptResponse {
240 stop_reason: acp::StopReason::Canceled,
241 })
242 } else {
243 Err(anyhow!(details))
244 }
245 }
246 Err(_) => Err(anyhow!(err)),
247 }
248 }
249 }
250 })
251 }
252
253 fn prompt_capabilities(&self) -> acp::PromptCapabilities {
254 self.prompt_capabilities
255 }
256
257 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
258 if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
259 session.pending_cancel = true;
260 }
261 let conn = self.connection.clone();
262 let params = acp::CancelNotification {
263 session_id: session_id.clone(),
264 };
265 let sessions = self.sessions.clone();
266 let session_id = session_id.clone();
267 cx.foreground_executor()
268 .spawn(async move {
269 let resp = conn.cancel(params).await;
270 if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
271 session.pending_cancel = false;
272 }
273 resp
274 })
275 .detach();
276 }
277
278 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
279 self
280 }
281}
282
/// Handles agent-initiated ACP calls by routing them to the [`AcpThread`]
/// registered for the call's session.
struct ClientDelegate {
    /// Session map shared with the owning `AcpConnection`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App context; each callback clones it to update threads.
    cx: AsyncApp,
}
287
288impl acp::Client for ClientDelegate {
289 async fn request_permission(
290 &self,
291 arguments: acp::RequestPermissionRequest,
292 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
293 let cx = &mut self.cx.clone();
294 let rx = self
295 .sessions
296 .borrow()
297 .get(&arguments.session_id)
298 .context("Failed to get session")?
299 .thread
300 .update(cx, |thread, cx| {
301 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
302 })?;
303
304 let result = rx?.await;
305
306 let outcome = match result {
307 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
308 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
309 };
310
311 Ok(acp::RequestPermissionResponse { outcome })
312 }
313
314 async fn write_text_file(
315 &self,
316 arguments: acp::WriteTextFileRequest,
317 ) -> Result<(), acp::Error> {
318 let cx = &mut self.cx.clone();
319 let task = self
320 .sessions
321 .borrow()
322 .get(&arguments.session_id)
323 .context("Failed to get session")?
324 .thread
325 .update(cx, |thread, cx| {
326 thread.write_text_file(arguments.path, arguments.content, cx)
327 })?;
328
329 task.await?;
330
331 Ok(())
332 }
333
334 async fn read_text_file(
335 &self,
336 arguments: acp::ReadTextFileRequest,
337 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
338 let cx = &mut self.cx.clone();
339 let task = self
340 .sessions
341 .borrow()
342 .get(&arguments.session_id)
343 .context("Failed to get session")?
344 .thread
345 .update(cx, |thread, cx| {
346 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
347 })?;
348
349 let content = task.await?;
350
351 Ok(acp::ReadTextFileResponse { content })
352 }
353
354 async fn session_notification(
355 &self,
356 notification: acp::SessionNotification,
357 ) -> Result<(), acp::Error> {
358 let cx = &mut self.cx.clone();
359 let sessions = self.sessions.borrow();
360 let session = sessions
361 .get(¬ification.session_id)
362 .context("Failed to get session")?;
363
364 session.thread.update(cx, |thread, cx| {
365 thread.handle_session_update(notification.update, cx)
366 })??;
367
368 Ok(())
369 }
370}