acp.rs

  1use crate::AgentServerCommand;
  2use acp_thread::AgentConnection;
  3use acp_tools::AcpConnectionRegistry;
  4use action_log::ActionLog;
  5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
  6use anyhow::anyhow;
  7use collections::HashMap;
  8use futures::AsyncBufReadExt as _;
  9use futures::io::BufReader;
 10use project::Project;
 11use serde::Deserialize;
 12
 13use std::{any::Any, cell::RefCell};
 14use std::{path::Path, rc::Rc};
 15use thiserror::Error;
 16
 17use anyhow::{Context as _, Result};
 18use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 19
 20use acp_thread::{AcpThread, AuthRequired, LoadError};
 21
 22#[derive(Debug, Error)]
 23#[error("Unsupported version")]
 24pub struct UnsupportedVersion;
 25
/// A connection to an external agent server that speaks the Agent Client
/// Protocol (ACP) over a spawned child process's stdin/stdout.
pub struct AcpConnection {
    /// Name of the agent server; passed to new threads and used when
    /// registering the connection with `AcpConnectionRegistry`.
    server_name: SharedString,
    /// Client side of the ACP connection, used to send requests to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions created through `new_thread`, keyed by ACP session id. Shared
    /// with the `ClientDelegate` so agent-initiated requests can be routed to
    /// the right thread.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised in the agent's `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities advertised in the agent's `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    // The tasks below are never read; they are stored so they live exactly as
    // long as the connection. NOTE(review): Rust drops fields in declaration
    // order, so reordering these changes teardown order.
    /// Drives the ACP message I/O over the child's stdio.
    _io_task: Task<Result<()>>,
    /// Waits for the child to exit and reports it to every open session thread.
    _wait_task: Task<Result<()>>,
    /// Forwards the child's stderr to the log.
    _stderr_task: Task<Result<()>>,
}
 36
/// Per-session state tracked by [`AcpConnection`].
pub struct AcpSession {
    /// The thread that owns this session (held weakly).
    thread: WeakEntity<AcpThread>,
    /// Set by `cancel`; read and reset by the next `prompt` completion so that
    /// the agent's "aborted" internal error is reported as a `Cancelled` stop
    /// reason instead of a failure.
    suppress_abort_err: bool,
}
 41
 42pub async fn connect(
 43    server_name: SharedString,
 44    command: AgentServerCommand,
 45    root_dir: &Path,
 46    cx: &mut AsyncApp,
 47) -> Result<Rc<dyn AgentConnection>> {
 48    let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
 49    Ok(Rc::new(conn) as _)
 50}
 51
/// Oldest ACP protocol version accepted from an agent. `AcpConnection::stdio`
/// rejects older versions with [`UnsupportedVersion`] after `initialize`.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 53
 54impl AcpConnection {
 55    pub async fn stdio(
 56        server_name: SharedString,
 57        command: AgentServerCommand,
 58        root_dir: &Path,
 59        cx: &mut AsyncApp,
 60    ) -> Result<Self> {
 61        let mut child = util::command::new_smol_command(command.path)
 62            .args(command.args.iter().map(|arg| arg.as_str()))
 63            .envs(command.env.iter().flatten())
 64            .current_dir(root_dir)
 65            .stdin(std::process::Stdio::piped())
 66            .stdout(std::process::Stdio::piped())
 67            .stderr(std::process::Stdio::piped())
 68            .kill_on_drop(true)
 69            .spawn()?;
 70
 71        let stdout = child.stdout.take().context("Failed to take stdout")?;
 72        let stdin = child.stdin.take().context("Failed to take stdin")?;
 73        let stderr = child.stderr.take().context("Failed to take stderr")?;
 74        log::trace!("Spawned (pid: {})", child.id());
 75
 76        let sessions = Rc::new(RefCell::new(HashMap::default()));
 77
 78        let client = ClientDelegate {
 79            sessions: sessions.clone(),
 80            cx: cx.clone(),
 81        };
 82        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 83            let foreground_executor = cx.foreground_executor().clone();
 84            move |fut| {
 85                foreground_executor.spawn(fut).detach();
 86            }
 87        });
 88
 89        let io_task = cx.background_spawn(io_task);
 90
 91        let stderr_task = cx.background_spawn(async move {
 92            let mut stderr = BufReader::new(stderr);
 93            let mut line = String::new();
 94            while let Ok(n) = stderr.read_line(&mut line).await
 95                && n > 0
 96            {
 97                log::warn!("agent stderr: {}", &line);
 98                line.clear();
 99            }
100            Ok(())
101        });
102
103        let wait_task = cx.spawn({
104            let sessions = sessions.clone();
105            async move |cx| {
106                let status = child.status().await?;
107
108                for session in sessions.borrow().values() {
109                    session
110                        .thread
111                        .update(cx, |thread, cx| {
112                            thread.emit_load_error(LoadError::Exited { status }, cx)
113                        })
114                        .ok();
115                }
116
117                anyhow::Ok(())
118            }
119        });
120
121        let connection = Rc::new(connection);
122
123        cx.update(|cx| {
124            AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
125                registry.set_active_connection(server_name.clone(), &connection, cx)
126            });
127        })?;
128
129        let response = connection
130            .initialize(acp::InitializeRequest {
131                protocol_version: acp::VERSION,
132                client_capabilities: acp::ClientCapabilities {
133                    fs: acp::FileSystemCapability {
134                        read_text_file: true,
135                        write_text_file: true,
136                    },
137                    terminal: true,
138                },
139            })
140            .await?;
141
142        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
143            return Err(UnsupportedVersion.into());
144        }
145
146        Ok(Self {
147            auth_methods: response.auth_methods,
148            connection,
149            server_name,
150            sessions,
151            agent_capabilities: response.agent_capabilities,
152            _io_task: io_task,
153            _wait_task: wait_task,
154            _stderr_task: stderr_task,
155        })
156    }
157
158    pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
159        &self.agent_capabilities.prompt_capabilities
160    }
161}
162
163impl AgentConnection for AcpConnection {
164    fn new_thread(
165        self: Rc<Self>,
166        project: Entity<Project>,
167        cwd: &Path,
168        cx: &mut App,
169    ) -> Task<Result<Entity<AcpThread>>> {
170        let conn = self.connection.clone();
171        let sessions = self.sessions.clone();
172        let cwd = cwd.to_path_buf();
173        let context_server_store = project.read(cx).context_server_store().read(cx);
174        let mcp_servers = context_server_store
175            .configured_server_ids()
176            .iter()
177            .filter_map(|id| {
178                let configuration = context_server_store.configuration_for_server(id)?;
179                let command = configuration.command();
180                Some(acp::McpServer {
181                    name: id.0.to_string(),
182                    command: command.path.clone(),
183                    args: command.args.clone(),
184                    env: if let Some(env) = command.env.as_ref() {
185                        env.iter()
186                            .map(|(name, value)| acp::EnvVariable {
187                                name: name.clone(),
188                                value: value.clone(),
189                            })
190                            .collect()
191                    } else {
192                        vec![]
193                    },
194                })
195            })
196            .collect();
197
198        cx.spawn(async move |cx| {
199            let response = conn
200                .new_session(acp::NewSessionRequest { mcp_servers, cwd })
201                .await
202                .map_err(|err| {
203                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
204                        let mut error = AuthRequired::new();
205
206                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
207                            error = error.with_description(err.message);
208                        }
209
210                        anyhow!(error)
211                    } else {
212                        anyhow!(err)
213                    }
214                })?;
215
216            let session_id = response.session_id;
217            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
218            let thread = cx.new(|cx| {
219                AcpThread::new(
220                    self.server_name.clone(),
221                    self.clone(),
222                    project,
223                    action_log,
224                    session_id.clone(),
225                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
226                    watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
227                    cx,
228                )
229            })?;
230
231            let session = AcpSession {
232                thread: thread.downgrade(),
233                suppress_abort_err: false,
234            };
235            sessions.borrow_mut().insert(session_id, session);
236
237            Ok(thread)
238        })
239    }
240
241    fn auth_methods(&self) -> &[acp::AuthMethod] {
242        &self.auth_methods
243    }
244
245    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
246        let conn = self.connection.clone();
247        cx.foreground_executor().spawn(async move {
248            let result = conn
249                .authenticate(acp::AuthenticateRequest {
250                    method_id: method_id.clone(),
251                })
252                .await?;
253
254            Ok(result)
255        })
256    }
257
258    fn prompt(
259        &self,
260        _id: Option<acp_thread::UserMessageId>,
261        params: acp::PromptRequest,
262        cx: &mut App,
263    ) -> Task<Result<acp::PromptResponse>> {
264        let conn = self.connection.clone();
265        let sessions = self.sessions.clone();
266        let session_id = params.session_id.clone();
267        cx.foreground_executor().spawn(async move {
268            let result = conn.prompt(params).await;
269
270            let mut suppress_abort_err = false;
271
272            if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
273                suppress_abort_err = session.suppress_abort_err;
274                session.suppress_abort_err = false;
275            }
276
277            match result {
278                Ok(response) => Ok(response),
279                Err(err) => {
280                    if err.code != ErrorCode::INTERNAL_ERROR.code {
281                        anyhow::bail!(err)
282                    }
283
284                    let Some(data) = &err.data else {
285                        anyhow::bail!(err)
286                    };
287
288                    // Temporary workaround until the following PR is generally available:
289                    // https://github.com/google-gemini/gemini-cli/pull/6656
290
291                    #[derive(Deserialize)]
292                    #[serde(deny_unknown_fields)]
293                    struct ErrorDetails {
294                        details: Box<str>,
295                    }
296
297                    match serde_json::from_value(data.clone()) {
298                        Ok(ErrorDetails { details }) => {
299                            if suppress_abort_err
300                                && (details.contains("This operation was aborted")
301                                    || details.contains("The user aborted a request"))
302                            {
303                                Ok(acp::PromptResponse {
304                                    stop_reason: acp::StopReason::Cancelled,
305                                })
306                            } else {
307                                Err(anyhow!(details))
308                            }
309                        }
310                        Err(_) => Err(anyhow!(err)),
311                    }
312                }
313            }
314        })
315    }
316
317    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
318        if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
319            session.suppress_abort_err = true;
320        }
321        let conn = self.connection.clone();
322        let params = acp::CancelNotification {
323            session_id: session_id.clone(),
324        };
325        cx.foreground_executor()
326            .spawn(async move { conn.cancel(params).await })
327            .detach();
328    }
329
330    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
331        self
332    }
333}
334
/// Handles agent-to-client ACP requests and notifications by routing each one
/// to the `AcpThread` registered for the session it names.
struct ClientDelegate {
    /// Session table shared with the owning `AcpConnection`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app handle used to update and read session threads.
    cx: AsyncApp,
}
339
340impl acp::Client for ClientDelegate {
341    async fn request_permission(
342        &self,
343        arguments: acp::RequestPermissionRequest,
344    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
345        let cx = &mut self.cx.clone();
346
347        let task = self
348            .session_thread(&arguments.session_id)?
349            .update(cx, |thread, cx| {
350                thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
351            })??;
352
353        let outcome = task.await;
354
355        Ok(acp::RequestPermissionResponse { outcome })
356    }
357
358    async fn write_text_file(
359        &self,
360        arguments: acp::WriteTextFileRequest,
361    ) -> Result<(), acp::Error> {
362        let cx = &mut self.cx.clone();
363        let task = self
364            .session_thread(&arguments.session_id)?
365            .update(cx, |thread, cx| {
366                thread.write_text_file(arguments.path, arguments.content, cx)
367            })?;
368
369        task.await?;
370
371        Ok(())
372    }
373
374    async fn read_text_file(
375        &self,
376        arguments: acp::ReadTextFileRequest,
377    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
378        let task = self.session_thread(&arguments.session_id)?.update(
379            &mut self.cx.clone(),
380            |thread, cx| {
381                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
382            },
383        )?;
384
385        let content = task.await?;
386
387        Ok(acp::ReadTextFileResponse { content })
388    }
389
390    async fn session_notification(
391        &self,
392        notification: acp::SessionNotification,
393    ) -> Result<(), acp::Error> {
394        self.session_thread(&notification.session_id)?
395            .update(&mut self.cx.clone(), |thread, cx| {
396                thread.handle_session_update(notification.update, cx)
397            })??;
398
399        Ok(())
400    }
401
402    async fn create_terminal(
403        &self,
404        args: acp::CreateTerminalRequest,
405    ) -> Result<acp::CreateTerminalResponse, acp::Error> {
406        let terminal = self
407            .session_thread(&args.session_id)?
408            .update(&mut self.cx.clone(), |thread, cx| {
409                thread.create_terminal(
410                    args.command,
411                    args.args,
412                    args.env,
413                    args.cwd,
414                    args.output_byte_limit,
415                    cx,
416                )
417            })?
418            .await?;
419        Ok(
420            terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
421                terminal_id: terminal.id().clone(),
422            })?,
423        )
424    }
425
426    async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
427        self.session_thread(&args.session_id)?
428            .update(&mut self.cx.clone(), |thread, cx| {
429                thread.kill_terminal(args.terminal_id, cx)
430            })??;
431
432        Ok(())
433    }
434
435    async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
436        self.session_thread(&args.session_id)?
437            .update(&mut self.cx.clone(), |thread, cx| {
438                thread.release_terminal(args.terminal_id, cx)
439            })??;
440
441        Ok(())
442    }
443
444    async fn terminal_output(
445        &self,
446        args: acp::TerminalOutputRequest,
447    ) -> Result<acp::TerminalOutputResponse, acp::Error> {
448        self.session_thread(&args.session_id)?
449            .read_with(&mut self.cx.clone(), |thread, cx| {
450                let out = thread
451                    .terminal(args.terminal_id)?
452                    .read(cx)
453                    .current_output(cx);
454
455                Ok(out)
456            })?
457    }
458
459    async fn wait_for_terminal_exit(
460        &self,
461        args: acp::WaitForTerminalExitRequest,
462    ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
463        let exit_status = self
464            .session_thread(&args.session_id)?
465            .update(&mut self.cx.clone(), |thread, cx| {
466                anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
467            })??
468            .await;
469
470        Ok(acp::WaitForTerminalExitResponse { exit_status })
471    }
472}
473
474impl ClientDelegate {
475    fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
476        let sessions = self.sessions.borrow();
477        sessions
478            .get(session_id)
479            .context("Failed to get session")
480            .map(|session| session.thread.clone())
481    }
482}