acp.rs

  1use crate::AgentServerCommand;
  2use acp_thread::AgentConnection;
  3use acp_tools::AcpConnectionRegistry;
  4use action_log::ActionLog;
  5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
  6use anyhow::anyhow;
  7use collections::HashMap;
  8use futures::AsyncBufReadExt as _;
  9use futures::io::BufReader;
 10use project::Project;
 11use serde::Deserialize;
 12
 13use std::{any::Any, cell::RefCell};
 14use std::{path::Path, rc::Rc};
 15use thiserror::Error;
 16
 17use anyhow::{Context as _, Result};
 18use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 19
 20use acp_thread::{AcpThread, AuthRequired, LoadError};
 21
 22#[derive(Debug, Error)]
 23#[error("Unsupported version")]
 24pub struct UnsupportedVersion;
 25
/// A connection to an external agent server speaking the Agent Client
/// Protocol (ACP) over the stdio pipes of a spawned child process.
pub struct AcpConnection {
    /// Name of the agent server; passed to every `AcpThread` created on this
    /// connection and used when registering it as the active connection.
    server_name: SharedString,
    /// Client side of the ACP connection, used to send requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Live sessions keyed by ACP session id. Shared with the `ClientDelegate`
    /// that handles incoming agent requests and with the process-exit watcher.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Prompt capabilities advertised by the agent in its `initialize` response.
    prompt_capabilities: acp::PromptCapabilities,
    // The tasks below are never read; they are held so their work stays tied to
    // the lifetime of this connection (presumably dropping a gpui `Task`
    // cancels it — confirm). NOTE(review): field order determines drop order.
    _io_task: Task<Result<()>>,
    _wait_task: Task<Result<()>>,
    _stderr_task: Task<Result<()>>,
}
 36
 37pub struct AcpSession {
 38    thread: WeakEntity<AcpThread>,
 39    suppress_abort_err: bool,
 40}
 41
 42pub async fn connect(
 43    server_name: SharedString,
 44    command: AgentServerCommand,
 45    root_dir: &Path,
 46    cx: &mut AsyncApp,
 47) -> Result<Rc<dyn AgentConnection>> {
 48    let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
 49    Ok(Rc::new(conn) as _)
 50}
 51
 52const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 53
 54impl AcpConnection {
 55    pub async fn stdio(
 56        server_name: SharedString,
 57        command: AgentServerCommand,
 58        root_dir: &Path,
 59        cx: &mut AsyncApp,
 60    ) -> Result<Self> {
 61        let mut child = util::command::new_smol_command(command.path)
 62            .args(command.args.iter().map(|arg| arg.as_str()))
 63            .envs(command.env.iter().flatten())
 64            .current_dir(root_dir)
 65            .stdin(std::process::Stdio::piped())
 66            .stdout(std::process::Stdio::piped())
 67            .stderr(std::process::Stdio::piped())
 68            .kill_on_drop(true)
 69            .spawn()?;
 70
 71        let stdout = child.stdout.take().context("Failed to take stdout")?;
 72        let stdin = child.stdin.take().context("Failed to take stdin")?;
 73        let stderr = child.stderr.take().context("Failed to take stderr")?;
 74        log::trace!("Spawned (pid: {})", child.id());
 75
 76        let sessions = Rc::new(RefCell::new(HashMap::default()));
 77
 78        let client = ClientDelegate {
 79            sessions: sessions.clone(),
 80            cx: cx.clone(),
 81        };
 82        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 83            let foreground_executor = cx.foreground_executor().clone();
 84            move |fut| {
 85                foreground_executor.spawn(fut).detach();
 86            }
 87        });
 88
 89        let io_task = cx.background_spawn(io_task);
 90
 91        let stderr_task = cx.background_spawn(async move {
 92            let mut stderr = BufReader::new(stderr);
 93            let mut line = String::new();
 94            while let Ok(n) = stderr.read_line(&mut line).await
 95                && n > 0
 96            {
 97                log::warn!("agent stderr: {}", &line);
 98                line.clear();
 99            }
100            Ok(())
101        });
102
103        let wait_task = cx.spawn({
104            let sessions = sessions.clone();
105            async move |cx| {
106                let status = child.status().await?;
107
108                for session in sessions.borrow().values() {
109                    session
110                        .thread
111                        .update(cx, |thread, cx| {
112                            thread.emit_load_error(LoadError::Exited { status }, cx)
113                        })
114                        .ok();
115                }
116
117                anyhow::Ok(())
118            }
119        });
120
121        let connection = Rc::new(connection);
122
123        cx.update(|cx| {
124            AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
125                registry.set_active_connection(server_name.clone(), &connection, cx)
126            });
127        })?;
128
129        let response = connection
130            .initialize(acp::InitializeRequest {
131                protocol_version: acp::VERSION,
132                client_capabilities: acp::ClientCapabilities {
133                    fs: acp::FileSystemCapability {
134                        read_text_file: true,
135                        write_text_file: true,
136                    },
137                },
138            })
139            .await?;
140
141        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
142            return Err(UnsupportedVersion.into());
143        }
144
145        Ok(Self {
146            auth_methods: response.auth_methods,
147            connection,
148            server_name,
149            sessions,
150            prompt_capabilities: response.agent_capabilities.prompt_capabilities,
151            _io_task: io_task,
152            _wait_task: wait_task,
153            _stderr_task: stderr_task,
154        })
155    }
156
157    pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
158        &self.prompt_capabilities
159    }
160}
161
162impl AgentConnection for AcpConnection {
163    fn new_thread(
164        self: Rc<Self>,
165        project: Entity<Project>,
166        cwd: &Path,
167        cx: &mut App,
168    ) -> Task<Result<Entity<AcpThread>>> {
169        let conn = self.connection.clone();
170        let sessions = self.sessions.clone();
171        let cwd = cwd.to_path_buf();
172        let context_server_store = project.read(cx).context_server_store().read(cx);
173        let mcp_servers = context_server_store
174            .configured_server_ids()
175            .iter()
176            .filter_map(|id| {
177                let configuration = context_server_store.configuration_for_server(id)?;
178                let command = configuration.command();
179                Some(acp::McpServer {
180                    name: id.0.to_string(),
181                    command: command.path.clone(),
182                    args: command.args.clone(),
183                    env: if let Some(env) = command.env.as_ref() {
184                        env.iter()
185                            .map(|(name, value)| acp::EnvVariable {
186                                name: name.clone(),
187                                value: value.clone(),
188                            })
189                            .collect()
190                    } else {
191                        vec![]
192                    },
193                })
194            })
195            .collect();
196
197        cx.spawn(async move |cx| {
198            let response = conn
199                .new_session(acp::NewSessionRequest { mcp_servers, cwd })
200                .await
201                .map_err(|err| {
202                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
203                        let mut error = AuthRequired::new();
204
205                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
206                            error = error.with_description(err.message);
207                        }
208
209                        anyhow!(error)
210                    } else {
211                        anyhow!(err)
212                    }
213                })?;
214
215            let session_id = response.session_id;
216            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
217            let thread = cx.new(|cx| {
218                AcpThread::new(
219                    self.server_name.clone(),
220                    self.clone(),
221                    project,
222                    action_log,
223                    session_id.clone(),
224                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
225                    watch::Receiver::constant(self.prompt_capabilities),
226                    cx,
227                )
228            })?;
229
230            let session = AcpSession {
231                thread: thread.downgrade(),
232                suppress_abort_err: false,
233            };
234            sessions.borrow_mut().insert(session_id, session);
235
236            Ok(thread)
237        })
238    }
239
240    fn auth_methods(&self) -> &[acp::AuthMethod] {
241        &self.auth_methods
242    }
243
244    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
245        let conn = self.connection.clone();
246        cx.foreground_executor().spawn(async move {
247            let result = conn
248                .authenticate(acp::AuthenticateRequest {
249                    method_id: method_id.clone(),
250                })
251                .await?;
252
253            Ok(result)
254        })
255    }
256
257    fn prompt(
258        &self,
259        _id: Option<acp_thread::UserMessageId>,
260        params: acp::PromptRequest,
261        cx: &mut App,
262    ) -> Task<Result<acp::PromptResponse>> {
263        let conn = self.connection.clone();
264        let sessions = self.sessions.clone();
265        let session_id = params.session_id.clone();
266        cx.foreground_executor().spawn(async move {
267            let result = conn.prompt(params).await;
268
269            let mut suppress_abort_err = false;
270
271            if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
272                suppress_abort_err = session.suppress_abort_err;
273                session.suppress_abort_err = false;
274            }
275
276            match result {
277                Ok(response) => Ok(response),
278                Err(err) => {
279                    if err.code != ErrorCode::INTERNAL_ERROR.code {
280                        anyhow::bail!(err)
281                    }
282
283                    let Some(data) = &err.data else {
284                        anyhow::bail!(err)
285                    };
286
287                    // Temporary workaround until the following PR is generally available:
288                    // https://github.com/google-gemini/gemini-cli/pull/6656
289
290                    #[derive(Deserialize)]
291                    #[serde(deny_unknown_fields)]
292                    struct ErrorDetails {
293                        details: Box<str>,
294                    }
295
296                    match serde_json::from_value(data.clone()) {
297                        Ok(ErrorDetails { details }) => {
298                            if suppress_abort_err
299                                && (details.contains("This operation was aborted")
300                                    || details.contains("The user aborted a request"))
301                            {
302                                Ok(acp::PromptResponse {
303                                    stop_reason: acp::StopReason::Cancelled,
304                                })
305                            } else {
306                                Err(anyhow!(details))
307                            }
308                        }
309                        Err(_) => Err(anyhow!(err)),
310                    }
311                }
312            }
313        })
314    }
315
316    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
317        if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
318            session.suppress_abort_err = true;
319        }
320        let conn = self.connection.clone();
321        let params = acp::CancelNotification {
322            session_id: session_id.clone(),
323        };
324        cx.foreground_executor()
325            .spawn(async move { conn.cancel(params).await })
326            .detach();
327    }
328
329    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
330        self
331    }
332}
333
334struct ClientDelegate {
335    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
336    cx: AsyncApp,
337}
338
339impl acp::Client for ClientDelegate {
340    async fn request_permission(
341        &self,
342        arguments: acp::RequestPermissionRequest,
343    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
344        let cx = &mut self.cx.clone();
345
346        let task = self
347            .sessions
348            .borrow()
349            .get(&arguments.session_id)
350            .context("Failed to get session")?
351            .thread
352            .update(cx, |thread, cx| {
353                thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
354            })??;
355
356        let outcome = task.await;
357
358        Ok(acp::RequestPermissionResponse { outcome })
359    }
360
361    async fn write_text_file(
362        &self,
363        arguments: acp::WriteTextFileRequest,
364    ) -> Result<(), acp::Error> {
365        let cx = &mut self.cx.clone();
366        let task = self
367            .sessions
368            .borrow()
369            .get(&arguments.session_id)
370            .context("Failed to get session")?
371            .thread
372            .update(cx, |thread, cx| {
373                thread.write_text_file(arguments.path, arguments.content, cx)
374            })?;
375
376        task.await?;
377
378        Ok(())
379    }
380
381    async fn read_text_file(
382        &self,
383        arguments: acp::ReadTextFileRequest,
384    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
385        let cx = &mut self.cx.clone();
386        let task = self
387            .sessions
388            .borrow()
389            .get(&arguments.session_id)
390            .context("Failed to get session")?
391            .thread
392            .update(cx, |thread, cx| {
393                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
394            })?;
395
396        let content = task.await?;
397
398        Ok(acp::ReadTextFileResponse { content })
399    }
400
401    async fn session_notification(
402        &self,
403        notification: acp::SessionNotification,
404    ) -> Result<(), acp::Error> {
405        let cx = &mut self.cx.clone();
406        let sessions = self.sessions.borrow();
407        let session = sessions
408            .get(&notification.session_id)
409            .context("Failed to get session")?;
410
411        session.thread.update(cx, |thread, cx| {
412            thread.handle_session_update(notification.update, cx)
413        })??;
414
415        Ok(())
416    }
417}