// stdio_agent_server.rs

use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
use acp_thread::{AcpClientDelegate, AcpThread, LoadError};
use agentic_coding_protocol as acp;
use anyhow::{Context as _, Result, anyhow};
use gpui::{App, AsyncApp, Entity, Task, prelude::*};
use project::Project;
use std::path::Path;
use util::ResultExt;

 10pub trait StdioAgentServer: Send + Clone {
 11    fn logo(&self) -> ui::IconName;
 12    fn name(&self) -> &'static str;
 13    fn empty_state_headline(&self) -> &'static str;
 14    fn empty_state_message(&self) -> &'static str;
 15    fn supports_always_allow(&self) -> bool;
 16
 17    fn command(
 18        &self,
 19        project: &Entity<Project>,
 20        cx: &mut AsyncApp,
 21    ) -> impl Future<Output = Result<AgentServerCommand>>;
 22
 23    fn version(
 24        &self,
 25        command: &AgentServerCommand,
 26    ) -> impl Future<Output = Result<AgentServerVersion>> + Send;
 27}
 28
 29impl<T: StdioAgentServer + 'static> AgentServer for T {
 30    fn name(&self) -> &'static str {
 31        self.name()
 32    }
 33
 34    fn empty_state_headline(&self) -> &'static str {
 35        self.empty_state_headline()
 36    }
 37
 38    fn empty_state_message(&self) -> &'static str {
 39        self.empty_state_message()
 40    }
 41
 42    fn logo(&self) -> ui::IconName {
 43        self.logo()
 44    }
 45
 46    fn supports_always_allow(&self) -> bool {
 47        self.supports_always_allow()
 48    }
 49
 50    fn new_thread(
 51        &self,
 52        root_dir: &Path,
 53        project: &Entity<Project>,
 54        cx: &mut App,
 55    ) -> Task<Result<Entity<AcpThread>>> {
 56        let root_dir = root_dir.to_path_buf();
 57        let project = project.clone();
 58        let this = self.clone();
 59        let title = self.name().into();
 60
 61        cx.spawn(async move |cx| {
 62            let command = this.command(&project, cx).await?;
 63
 64            let mut child = util::command::new_smol_command(&command.path)
 65                .args(command.args.iter())
 66                .current_dir(root_dir)
 67                .stdin(std::process::Stdio::piped())
 68                .stdout(std::process::Stdio::piped())
 69                .stderr(std::process::Stdio::inherit())
 70                .kill_on_drop(true)
 71                .spawn()?;
 72
 73            let stdin = child.stdin.take().unwrap();
 74            let stdout = child.stdout.take().unwrap();
 75
 76            cx.new(|cx| {
 77                let foreground_executor = cx.foreground_executor().clone();
 78
 79                let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
 80                    AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
 81                    stdin,
 82                    stdout,
 83                    move |fut| foreground_executor.spawn(fut).detach(),
 84                );
 85
 86                let io_task = cx.background_spawn(async move {
 87                    io_fut.await.log_err();
 88                });
 89
 90                let child_status = cx.background_spawn(async move {
 91                    let result = match child.status().await {
 92                        Err(e) => Err(anyhow!(e)),
 93                        Ok(result) if result.success() => Ok(()),
 94                        Ok(result) => {
 95                            if let Some(AgentServerVersion::Unsupported {
 96                                error_message,
 97                                upgrade_message,
 98                                upgrade_command,
 99                            }) = this.version(&command).await.log_err()
100                            {
101                                Err(anyhow!(LoadError::Unsupported {
102                                    error_message,
103                                    upgrade_message,
104                                    upgrade_command
105                                }))
106                            } else {
107                                Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
108                            }
109                        }
110                    };
111                    drop(io_task);
112                    result
113                });
114
115                AcpThread::new(connection, title, Some(child_status), project.clone(), cx)
116            })
117        })
118    }
119}