acp.rs

  1use crate::AgentServerCommand;
  2use acp_thread::AgentConnection;
  3use acp_tools::AcpConnectionRegistry;
  4use action_log::ActionLog;
  5use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
  6use anyhow::anyhow;
  7use collections::HashMap;
  8use futures::AsyncBufReadExt as _;
  9use futures::io::BufReader;
 10use project::Project;
 11use serde::Deserialize;
 12
 13use std::{any::Any, cell::RefCell};
 14use std::{path::Path, rc::Rc};
 15use thiserror::Error;
 16
 17use anyhow::{Context as _, Result};
 18use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 19
 20use acp_thread::{AcpThread, AuthRequired, LoadError};
 21
 22#[derive(Debug, Error)]
 23#[error("Unsupported version")]
 24pub struct UnsupportedVersion;
 25
 26pub struct AcpConnection {
 27    server_name: SharedString,
 28    connection: Rc<acp::ClientSideConnection>,
 29    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
 30    auth_methods: Vec<acp::AuthMethod>,
 31    agent_capabilities: acp::AgentCapabilities,
 32    _io_task: Task<Result<()>>,
 33    _wait_task: Task<Result<()>>,
 34    _stderr_task: Task<Result<()>>,
 35}
 36
 37pub struct AcpSession {
 38    thread: WeakEntity<AcpThread>,
 39    suppress_abort_err: bool,
 40}
 41
 42pub async fn connect(
 43    server_name: SharedString,
 44    command: AgentServerCommand,
 45    root_dir: &Path,
 46    cx: &mut AsyncApp,
 47) -> Result<Rc<dyn AgentConnection>> {
 48    let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, cx).await?;
 49    Ok(Rc::new(conn) as _)
 50}
 51
 52const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 53
 54impl AcpConnection {
 55    pub async fn stdio(
 56        server_name: SharedString,
 57        command: AgentServerCommand,
 58        root_dir: &Path,
 59        cx: &mut AsyncApp,
 60    ) -> Result<Self> {
 61        let mut child = util::command::new_smol_command(command.path)
 62            .args(command.args.iter().map(|arg| arg.as_str()))
 63            .envs(command.env.iter().flatten())
 64            .current_dir(root_dir)
 65            .stdin(std::process::Stdio::piped())
 66            .stdout(std::process::Stdio::piped())
 67            .stderr(std::process::Stdio::piped())
 68            .kill_on_drop(true)
 69            .spawn()?;
 70
 71        let stdout = child.stdout.take().context("Failed to take stdout")?;
 72        let stdin = child.stdin.take().context("Failed to take stdin")?;
 73        let stderr = child.stderr.take().context("Failed to take stderr")?;
 74        log::trace!("Spawned (pid: {})", child.id());
 75
 76        let sessions = Rc::new(RefCell::new(HashMap::default()));
 77
 78        let client = ClientDelegate {
 79            sessions: sessions.clone(),
 80            cx: cx.clone(),
 81        };
 82        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 83            let foreground_executor = cx.foreground_executor().clone();
 84            move |fut| {
 85                foreground_executor.spawn(fut).detach();
 86            }
 87        });
 88
 89        let io_task = cx.background_spawn(io_task);
 90
 91        let stderr_task = cx.background_spawn(async move {
 92            let mut stderr = BufReader::new(stderr);
 93            let mut line = String::new();
 94            while let Ok(n) = stderr.read_line(&mut line).await
 95                && n > 0
 96            {
 97                log::warn!("agent stderr: {}", &line);
 98                line.clear();
 99            }
100            Ok(())
101        });
102
103        let wait_task = cx.spawn({
104            let sessions = sessions.clone();
105            async move |cx| {
106                let status = child.status().await?;
107
108                for session in sessions.borrow().values() {
109                    session
110                        .thread
111                        .update(cx, |thread, cx| {
112                            thread.emit_load_error(LoadError::Exited { status }, cx)
113                        })
114                        .ok();
115                }
116
117                anyhow::Ok(())
118            }
119        });
120
121        let connection = Rc::new(connection);
122
123        cx.update(|cx| {
124            AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
125                registry.set_active_connection(server_name.clone(), &connection, cx)
126            });
127        })?;
128
129        let response = connection
130            .initialize(acp::InitializeRequest {
131                protocol_version: acp::VERSION,
132                client_capabilities: acp::ClientCapabilities {
133                    fs: acp::FileSystemCapability {
134                        read_text_file: true,
135                        write_text_file: true,
136                    },
137                    terminal: true,
138                },
139            })
140            .await?;
141
142        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
143            return Err(UnsupportedVersion.into());
144        }
145
146        Ok(Self {
147            auth_methods: response.auth_methods,
148            connection,
149            server_name,
150            sessions,
151            agent_capabilities: response.agent_capabilities,
152            _io_task: io_task,
153            _wait_task: wait_task,
154            _stderr_task: stderr_task,
155        })
156    }
157
158    pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
159        &self.agent_capabilities.prompt_capabilities
160    }
161}
162
163impl AgentConnection for AcpConnection {
164    fn new_thread(
165        self: Rc<Self>,
166        project: Entity<Project>,
167        cwd: &Path,
168        cx: &mut App,
169    ) -> Task<Result<Entity<AcpThread>>> {
170        let conn = self.connection.clone();
171        let sessions = self.sessions.clone();
172        let cwd = cwd.to_path_buf();
173        let context_server_store = project.read(cx).context_server_store().read(cx);
174        let mcp_servers = context_server_store
175            .configured_server_ids()
176            .iter()
177            .filter_map(|id| {
178                let configuration = context_server_store.configuration_for_server(id)?;
179                let command = configuration.command();
180                Some(acp::McpServer {
181                    name: id.0.to_string(),
182                    command: command.path.clone(),
183                    args: command.args.clone(),
184                    env: if let Some(env) = command.env.as_ref() {
185                        env.iter()
186                            .map(|(name, value)| acp::EnvVariable {
187                                name: name.clone(),
188                                value: value.clone(),
189                            })
190                            .collect()
191                    } else {
192                        vec![]
193                    },
194                })
195            })
196            .collect();
197
198        cx.spawn(async move |cx| {
199            let response = conn
200                .new_session(acp::NewSessionRequest { mcp_servers, cwd })
201                .await
202                .map_err(|err| {
203                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
204                        let mut error = AuthRequired::new();
205
206                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
207                            error = error.with_description(err.message);
208                        }
209
210                        anyhow!(error)
211                    } else {
212                        anyhow!(err)
213                    }
214                })?;
215
216            let session_id = response.session_id;
217            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
218            let thread = cx.new(|cx| {
219                AcpThread::new(
220                    self.server_name.clone(),
221                    self.clone(),
222                    project,
223                    action_log,
224                    session_id.clone(),
225                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
226                    watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
227                    response.available_commands,
228                    cx,
229                )
230            })?;
231
232            let session = AcpSession {
233                thread: thread.downgrade(),
234                suppress_abort_err: false,
235            };
236            sessions.borrow_mut().insert(session_id, session);
237
238            Ok(thread)
239        })
240    }
241
242    fn auth_methods(&self) -> &[acp::AuthMethod] {
243        &self.auth_methods
244    }
245
246    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
247        let conn = self.connection.clone();
248        cx.foreground_executor().spawn(async move {
249            let result = conn
250                .authenticate(acp::AuthenticateRequest {
251                    method_id: method_id.clone(),
252                })
253                .await?;
254
255            Ok(result)
256        })
257    }
258
259    fn prompt(
260        &self,
261        _id: Option<acp_thread::UserMessageId>,
262        params: acp::PromptRequest,
263        cx: &mut App,
264    ) -> Task<Result<acp::PromptResponse>> {
265        let conn = self.connection.clone();
266        let sessions = self.sessions.clone();
267        let session_id = params.session_id.clone();
268        cx.foreground_executor().spawn(async move {
269            let result = conn.prompt(params).await;
270
271            let mut suppress_abort_err = false;
272
273            if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
274                suppress_abort_err = session.suppress_abort_err;
275                session.suppress_abort_err = false;
276            }
277
278            match result {
279                Ok(response) => Ok(response),
280                Err(err) => {
281                    if err.code != ErrorCode::INTERNAL_ERROR.code {
282                        anyhow::bail!(err)
283                    }
284
285                    let Some(data) = &err.data else {
286                        anyhow::bail!(err)
287                    };
288
289                    // Temporary workaround until the following PR is generally available:
290                    // https://github.com/google-gemini/gemini-cli/pull/6656
291
292                    #[derive(Deserialize)]
293                    #[serde(deny_unknown_fields)]
294                    struct ErrorDetails {
295                        details: Box<str>,
296                    }
297
298                    match serde_json::from_value(data.clone()) {
299                        Ok(ErrorDetails { details }) => {
300                            if suppress_abort_err
301                                && (details.contains("This operation was aborted")
302                                    || details.contains("The user aborted a request"))
303                            {
304                                Ok(acp::PromptResponse {
305                                    stop_reason: acp::StopReason::Cancelled,
306                                })
307                            } else {
308                                Err(anyhow!(details))
309                            }
310                        }
311                        Err(_) => Err(anyhow!(err)),
312                    }
313                }
314            }
315        })
316    }
317
318    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
319        if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
320            session.suppress_abort_err = true;
321        }
322        let conn = self.connection.clone();
323        let params = acp::CancelNotification {
324            session_id: session_id.clone(),
325        };
326        cx.foreground_executor()
327            .spawn(async move { conn.cancel(params).await })
328            .detach();
329    }
330
331    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
332        self
333    }
334}
335
336struct ClientDelegate {
337    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
338    cx: AsyncApp,
339}
340
341impl acp::Client for ClientDelegate {
342    async fn request_permission(
343        &self,
344        arguments: acp::RequestPermissionRequest,
345    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
346        let cx = &mut self.cx.clone();
347
348        let task = self
349            .session_thread(&arguments.session_id)?
350            .update(cx, |thread, cx| {
351                thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
352            })??;
353
354        let outcome = task.await;
355
356        Ok(acp::RequestPermissionResponse { outcome })
357    }
358
359    async fn write_text_file(
360        &self,
361        arguments: acp::WriteTextFileRequest,
362    ) -> Result<(), acp::Error> {
363        let cx = &mut self.cx.clone();
364        let task = self
365            .session_thread(&arguments.session_id)?
366            .update(cx, |thread, cx| {
367                thread.write_text_file(arguments.path, arguments.content, cx)
368            })?;
369
370        task.await?;
371
372        Ok(())
373    }
374
375    async fn read_text_file(
376        &self,
377        arguments: acp::ReadTextFileRequest,
378    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
379        let task = self.session_thread(&arguments.session_id)?.update(
380            &mut self.cx.clone(),
381            |thread, cx| {
382                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
383            },
384        )?;
385
386        let content = task.await?;
387
388        Ok(acp::ReadTextFileResponse { content })
389    }
390
391    async fn session_notification(
392        &self,
393        notification: acp::SessionNotification,
394    ) -> Result<(), acp::Error> {
395        self.session_thread(&notification.session_id)?
396            .update(&mut self.cx.clone(), |thread, cx| {
397                thread.handle_session_update(notification.update, cx)
398            })??;
399
400        Ok(())
401    }
402
403    async fn create_terminal(
404        &self,
405        args: acp::CreateTerminalRequest,
406    ) -> Result<acp::CreateTerminalResponse, acp::Error> {
407        let terminal = self
408            .session_thread(&args.session_id)?
409            .update(&mut self.cx.clone(), |thread, cx| {
410                thread.create_terminal(
411                    args.command,
412                    args.args,
413                    args.env,
414                    args.cwd,
415                    args.output_byte_limit,
416                    cx,
417                )
418            })?
419            .await?;
420        Ok(
421            terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
422                terminal_id: terminal.id().clone(),
423            })?,
424        )
425    }
426
427    async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
428        self.session_thread(&args.session_id)?
429            .update(&mut self.cx.clone(), |thread, cx| {
430                thread.kill_terminal(args.terminal_id, cx)
431            })??;
432
433        Ok(())
434    }
435
436    async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
437        self.session_thread(&args.session_id)?
438            .update(&mut self.cx.clone(), |thread, cx| {
439                thread.release_terminal(args.terminal_id, cx)
440            })??;
441
442        Ok(())
443    }
444
445    async fn terminal_output(
446        &self,
447        args: acp::TerminalOutputRequest,
448    ) -> Result<acp::TerminalOutputResponse, acp::Error> {
449        self.session_thread(&args.session_id)?
450            .read_with(&mut self.cx.clone(), |thread, cx| {
451                let out = thread
452                    .terminal(args.terminal_id)?
453                    .read(cx)
454                    .current_output(cx);
455
456                Ok(out)
457            })?
458    }
459
460    async fn wait_for_terminal_exit(
461        &self,
462        args: acp::WaitForTerminalExitRequest,
463    ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
464        let exit_status = self
465            .session_thread(&args.session_id)?
466            .update(&mut self.cx.clone(), |thread, cx| {
467                anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
468            })??
469            .await;
470
471        Ok(acp::WaitForTerminalExitResponse { exit_status })
472    }
473}
474
475impl ClientDelegate {
476    fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
477        let sessions = self.sessions.borrow();
478        sessions
479            .get(session_id)
480            .context("Failed to get session")
481            .map(|session| session.thread.clone())
482    }
483}