acp.rs

  1use acp_thread::AgentConnection;
  2use acp_tools::AcpConnectionRegistry;
  3use action_log::ActionLog;
  4use agent_client_protocol::{self as acp, Agent as _, ErrorCode};
  5use anyhow::anyhow;
  6use collections::HashMap;
  7use futures::AsyncBufReadExt as _;
  8use futures::io::BufReader;
  9use project::Project;
 10use project::agent_server_store::AgentServerCommand;
 11use serde::Deserialize;
 12
 13use std::path::PathBuf;
 14use std::{any::Any, cell::RefCell};
 15use std::{path::Path, rc::Rc};
 16use thiserror::Error;
 17
 18use anyhow::{Context as _, Result};
 19use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity};
 20
 21use acp_thread::{AcpThread, AuthRequired, LoadError};
 22
 23#[derive(Debug, Error)]
 24#[error("Unsupported version")]
 25pub struct UnsupportedVersion;
 26
/// A connection to an agent server running as a child process, speaking ACP
/// over the child's stdin/stdout.
pub struct AcpConnection {
    /// Name of the agent server; given to every new [`AcpThread`] and used when
    /// registering this connection with `AcpConnectionRegistry`.
    server_name: SharedString,
    /// Protocol handle used to send requests and notifications to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions created on this connection, by id. The same `Rc` is shared with
    /// `ClientDelegate` and the process wait task.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised in the agent's `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Capabilities advertised in the agent's `initialize` response.
    agent_capabilities: acp::AgentCapabilities,
    /// Root directory the connection was created for; also used as the child's
    /// working directory for non-remote projects.
    root_dir: PathBuf,
    // The tasks below are never read; they are stored so they live as long as
    // the connection (presumably dropping a gpui `Task` cancels it — confirm).
    // The wait task's future owns the child process, which is spawned with
    // `kill_on_drop(true)`. Field order here is also drop order.
    /// The I/O future returned by `ClientSideConnection::new`, run on the
    /// background executor.
    _io_task: Task<Result<()>>,
    /// Waits for the child to exit and emits `LoadError::Exited` on every session.
    _wait_task: Task<Result<()>>,
    /// Forwards each line of the child's stderr to the log at warn level.
    _stderr_task: Task<Result<()>>,
}
 38
 39pub struct AcpSession {
 40    thread: WeakEntity<AcpThread>,
 41    suppress_abort_err: bool,
 42}
 43
 44pub async fn connect(
 45    server_name: SharedString,
 46    command: AgentServerCommand,
 47    root_dir: &Path,
 48    is_remote: bool,
 49    cx: &mut AsyncApp,
 50) -> Result<Rc<dyn AgentConnection>> {
 51    let conn = AcpConnection::stdio(server_name, command.clone(), root_dir, is_remote, cx).await?;
 52    Ok(Rc::new(conn) as _)
 53}
 54
/// Oldest ACP protocol version accepted from an agent; `AcpConnection::stdio`
/// rejects an older `initialize` response with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
 56
 57impl AcpConnection {
 58    pub async fn stdio(
 59        server_name: SharedString,
 60        command: AgentServerCommand,
 61        root_dir: &Path,
 62        is_remote: bool,
 63        cx: &mut AsyncApp,
 64    ) -> Result<Self> {
 65        let mut child = util::command::new_smol_command(command.path);
 66        child
 67            .args(command.args.iter().map(|arg| arg.as_str()))
 68            .envs(command.env.iter().flatten())
 69            .stdin(std::process::Stdio::piped())
 70            .stdout(std::process::Stdio::piped())
 71            .stderr(std::process::Stdio::piped())
 72            .kill_on_drop(true);
 73        if !is_remote {
 74            child.current_dir(root_dir);
 75        }
 76        let mut child = child.spawn()?;
 77
 78        let stdout = child.stdout.take().context("Failed to take stdout")?;
 79        let stdin = child.stdin.take().context("Failed to take stdin")?;
 80        let stderr = child.stderr.take().context("Failed to take stderr")?;
 81        log::trace!("Spawned (pid: {})", child.id());
 82
 83        let sessions = Rc::new(RefCell::new(HashMap::default()));
 84
 85        let client = ClientDelegate {
 86            sessions: sessions.clone(),
 87            cx: cx.clone(),
 88        };
 89        let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
 90            let foreground_executor = cx.foreground_executor().clone();
 91            move |fut| {
 92                foreground_executor.spawn(fut).detach();
 93            }
 94        });
 95
 96        let io_task = cx.background_spawn(io_task);
 97
 98        let stderr_task = cx.background_spawn(async move {
 99            let mut stderr = BufReader::new(stderr);
100            let mut line = String::new();
101            while let Ok(n) = stderr.read_line(&mut line).await
102                && n > 0
103            {
104                log::warn!("agent stderr: {}", &line);
105                line.clear();
106            }
107            Ok(())
108        });
109
110        let wait_task = cx.spawn({
111            let sessions = sessions.clone();
112            async move |cx| {
113                let status = child.status().await?;
114
115                for session in sessions.borrow().values() {
116                    session
117                        .thread
118                        .update(cx, |thread, cx| {
119                            thread.emit_load_error(LoadError::Exited { status }, cx)
120                        })
121                        .ok();
122                }
123
124                anyhow::Ok(())
125            }
126        });
127
128        let connection = Rc::new(connection);
129
130        cx.update(|cx| {
131            AcpConnectionRegistry::default_global(cx).update(cx, |registry, cx| {
132                registry.set_active_connection(server_name.clone(), &connection, cx)
133            });
134        })?;
135
136        let response = connection
137            .initialize(acp::InitializeRequest {
138                protocol_version: acp::VERSION,
139                client_capabilities: acp::ClientCapabilities {
140                    fs: acp::FileSystemCapability {
141                        read_text_file: true,
142                        write_text_file: true,
143                    },
144                    terminal: true,
145                },
146            })
147            .await?;
148
149        if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
150            return Err(UnsupportedVersion.into());
151        }
152
153        Ok(Self {
154            auth_methods: response.auth_methods,
155            root_dir: root_dir.to_owned(),
156            connection,
157            server_name,
158            sessions,
159            agent_capabilities: response.agent_capabilities,
160            _io_task: io_task,
161            _wait_task: wait_task,
162            _stderr_task: stderr_task,
163        })
164    }
165
166    pub fn prompt_capabilities(&self) -> &acp::PromptCapabilities {
167        &self.agent_capabilities.prompt_capabilities
168    }
169
170    pub fn root_dir(&self) -> &Path {
171        &self.root_dir
172    }
173}
174
175impl AgentConnection for AcpConnection {
176    fn new_thread(
177        self: Rc<Self>,
178        project: Entity<Project>,
179        cwd: &Path,
180        cx: &mut App,
181    ) -> Task<Result<Entity<AcpThread>>> {
182        let conn = self.connection.clone();
183        let sessions = self.sessions.clone();
184        let cwd = cwd.to_path_buf();
185        let context_server_store = project.read(cx).context_server_store().read(cx);
186        let mcp_servers = if project.read(cx).is_local() {
187            context_server_store
188                .configured_server_ids()
189                .iter()
190                .filter_map(|id| {
191                    let configuration = context_server_store.configuration_for_server(id)?;
192                    let command = configuration.command();
193                    Some(acp::McpServer {
194                        name: id.0.to_string(),
195                        command: command.path.clone(),
196                        args: command.args.clone(),
197                        env: if let Some(env) = command.env.as_ref() {
198                            env.iter()
199                                .map(|(name, value)| acp::EnvVariable {
200                                    name: name.clone(),
201                                    value: value.clone(),
202                                })
203                                .collect()
204                        } else {
205                            vec![]
206                        },
207                    })
208                })
209                .collect()
210        } else {
211            // In SSH projects, the external agent is running on the remote
212            // machine, and currently we only run MCP servers on the local
213            // machine. So don't pass any MCP servers to the agent in that case.
214            Vec::new()
215        };
216
217        cx.spawn(async move |cx| {
218            let response = conn
219                .new_session(acp::NewSessionRequest { mcp_servers, cwd })
220                .await
221                .map_err(|err| {
222                    if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
223                        let mut error = AuthRequired::new();
224
225                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
226                            error = error.with_description(err.message);
227                        }
228
229                        anyhow!(error)
230                    } else {
231                        anyhow!(err)
232                    }
233                })?;
234
235            let session_id = response.session_id;
236            let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
237            let thread = cx.new(|cx| {
238                AcpThread::new(
239                    self.server_name.clone(),
240                    self.clone(),
241                    project,
242                    action_log,
243                    session_id.clone(),
244                    // ACP doesn't currently support per-session prompt capabilities or changing capabilities dynamically.
245                    watch::Receiver::constant(self.agent_capabilities.prompt_capabilities),
246                    cx,
247                )
248            })?;
249
250            let session = AcpSession {
251                thread: thread.downgrade(),
252                suppress_abort_err: false,
253            };
254            sessions.borrow_mut().insert(session_id, session);
255
256            Ok(thread)
257        })
258    }
259
260    fn auth_methods(&self) -> &[acp::AuthMethod] {
261        &self.auth_methods
262    }
263
264    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
265        let conn = self.connection.clone();
266        cx.foreground_executor().spawn(async move {
267            let result = conn
268                .authenticate(acp::AuthenticateRequest {
269                    method_id: method_id.clone(),
270                })
271                .await?;
272
273            Ok(result)
274        })
275    }
276
277    fn prompt(
278        &self,
279        _id: Option<acp_thread::UserMessageId>,
280        params: acp::PromptRequest,
281        cx: &mut App,
282    ) -> Task<Result<acp::PromptResponse>> {
283        let conn = self.connection.clone();
284        let sessions = self.sessions.clone();
285        let session_id = params.session_id.clone();
286        cx.foreground_executor().spawn(async move {
287            let result = conn.prompt(params).await;
288
289            let mut suppress_abort_err = false;
290
291            if let Some(session) = sessions.borrow_mut().get_mut(&session_id) {
292                suppress_abort_err = session.suppress_abort_err;
293                session.suppress_abort_err = false;
294            }
295
296            match result {
297                Ok(response) => Ok(response),
298                Err(err) => {
299                    if err.code != ErrorCode::INTERNAL_ERROR.code {
300                        anyhow::bail!(err)
301                    }
302
303                    let Some(data) = &err.data else {
304                        anyhow::bail!(err)
305                    };
306
307                    // Temporary workaround until the following PR is generally available:
308                    // https://github.com/google-gemini/gemini-cli/pull/6656
309
310                    #[derive(Deserialize)]
311                    #[serde(deny_unknown_fields)]
312                    struct ErrorDetails {
313                        details: Box<str>,
314                    }
315
316                    match serde_json::from_value(data.clone()) {
317                        Ok(ErrorDetails { details }) => {
318                            if suppress_abort_err
319                                && (details.contains("This operation was aborted")
320                                    || details.contains("The user aborted a request"))
321                            {
322                                Ok(acp::PromptResponse {
323                                    stop_reason: acp::StopReason::Cancelled,
324                                })
325                            } else {
326                                Err(anyhow!(details))
327                            }
328                        }
329                        Err(_) => Err(anyhow!(err)),
330                    }
331                }
332            }
333        })
334    }
335
336    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
337        if let Some(session) = self.sessions.borrow_mut().get_mut(session_id) {
338            session.suppress_abort_err = true;
339        }
340        let conn = self.connection.clone();
341        let params = acp::CancelNotification {
342            session_id: session_id.clone(),
343        };
344        cx.foreground_executor()
345            .spawn(async move { conn.cancel(params).await })
346            .detach();
347    }
348
349    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
350        self
351    }
352}
353
354struct ClientDelegate {
355    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
356    cx: AsyncApp,
357}
358
359impl acp::Client for ClientDelegate {
360    async fn request_permission(
361        &self,
362        arguments: acp::RequestPermissionRequest,
363    ) -> Result<acp::RequestPermissionResponse, acp::Error> {
364        let cx = &mut self.cx.clone();
365
366        let task = self
367            .session_thread(&arguments.session_id)?
368            .update(cx, |thread, cx| {
369                thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
370            })??;
371
372        let outcome = task.await;
373
374        Ok(acp::RequestPermissionResponse { outcome })
375    }
376
377    async fn write_text_file(
378        &self,
379        arguments: acp::WriteTextFileRequest,
380    ) -> Result<(), acp::Error> {
381        let cx = &mut self.cx.clone();
382        let task = self
383            .session_thread(&arguments.session_id)?
384            .update(cx, |thread, cx| {
385                thread.write_text_file(arguments.path, arguments.content, cx)
386            })?;
387
388        task.await?;
389
390        Ok(())
391    }
392
393    async fn read_text_file(
394        &self,
395        arguments: acp::ReadTextFileRequest,
396    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
397        let task = self.session_thread(&arguments.session_id)?.update(
398            &mut self.cx.clone(),
399            |thread, cx| {
400                thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
401            },
402        )?;
403
404        let content = task.await?;
405
406        Ok(acp::ReadTextFileResponse { content })
407    }
408
409    async fn session_notification(
410        &self,
411        notification: acp::SessionNotification,
412    ) -> Result<(), acp::Error> {
413        self.session_thread(&notification.session_id)?
414            .update(&mut self.cx.clone(), |thread, cx| {
415                thread.handle_session_update(notification.update, cx)
416            })??;
417
418        Ok(())
419    }
420
421    async fn create_terminal(
422        &self,
423        args: acp::CreateTerminalRequest,
424    ) -> Result<acp::CreateTerminalResponse, acp::Error> {
425        let terminal = self
426            .session_thread(&args.session_id)?
427            .update(&mut self.cx.clone(), |thread, cx| {
428                thread.create_terminal(
429                    args.command,
430                    args.args,
431                    args.env,
432                    args.cwd,
433                    args.output_byte_limit,
434                    cx,
435                )
436            })?
437            .await?;
438        Ok(
439            terminal.read_with(&self.cx, |terminal, _| acp::CreateTerminalResponse {
440                terminal_id: terminal.id().clone(),
441            })?,
442        )
443    }
444
445    async fn kill_terminal(&self, args: acp::KillTerminalRequest) -> Result<(), acp::Error> {
446        self.session_thread(&args.session_id)?
447            .update(&mut self.cx.clone(), |thread, cx| {
448                thread.kill_terminal(args.terminal_id, cx)
449            })??;
450
451        Ok(())
452    }
453
454    async fn release_terminal(&self, args: acp::ReleaseTerminalRequest) -> Result<(), acp::Error> {
455        self.session_thread(&args.session_id)?
456            .update(&mut self.cx.clone(), |thread, cx| {
457                thread.release_terminal(args.terminal_id, cx)
458            })??;
459
460        Ok(())
461    }
462
463    async fn terminal_output(
464        &self,
465        args: acp::TerminalOutputRequest,
466    ) -> Result<acp::TerminalOutputResponse, acp::Error> {
467        self.session_thread(&args.session_id)?
468            .read_with(&mut self.cx.clone(), |thread, cx| {
469                let out = thread
470                    .terminal(args.terminal_id)?
471                    .read(cx)
472                    .current_output(cx);
473
474                Ok(out)
475            })?
476    }
477
478    async fn wait_for_terminal_exit(
479        &self,
480        args: acp::WaitForTerminalExitRequest,
481    ) -> Result<acp::WaitForTerminalExitResponse, acp::Error> {
482        let exit_status = self
483            .session_thread(&args.session_id)?
484            .update(&mut self.cx.clone(), |thread, cx| {
485                anyhow::Ok(thread.terminal(args.terminal_id)?.read(cx).wait_for_exit())
486            })??
487            .await;
488
489        Ok(acp::WaitForTerminalExitResponse { exit_status })
490    }
491}
492
493impl ClientDelegate {
494    fn session_thread(&self, session_id: &acp::SessionId) -> Result<WeakEntity<AcpThread>> {
495        let sessions = self.sessions.borrow();
496        sessions
497            .get(session_id)
498            .context("Failed to get session")
499            .map(|session| session.thread.clone())
500    }
501}