stdio_agent_server.rs

  1use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
  2use acp_thread::{AcpClientDelegate, AcpThread, LoadError};
  3use agentic_coding_protocol as acp;
  4use anyhow::{Result, anyhow};
  5use gpui::{App, AsyncApp, Entity, Task, prelude::*};
  6use project::Project;
  7use std::{
  8    path::{Path, PathBuf},
  9    sync::Arc,
 10};
 11use util::{ResultExt, paths};
 12
 13pub trait StdioAgentServer: Send + Clone {
 14    fn logo(&self) -> ui::IconName;
 15    fn name(&self) -> &'static str;
 16    fn empty_state_headline(&self) -> &'static str;
 17    fn empty_state_message(&self) -> &'static str;
 18    fn supports_always_allow(&self) -> bool;
 19
 20    fn command(
 21        &self,
 22        project: &Entity<Project>,
 23        cx: &mut AsyncApp,
 24    ) -> impl Future<Output = Result<AgentServerCommand>>;
 25
 26    fn version(
 27        &self,
 28        command: &AgentServerCommand,
 29    ) -> impl Future<Output = Result<AgentServerVersion>> + Send;
 30}
 31
 32impl<T: StdioAgentServer + 'static> AgentServer for T {
 33    fn name(&self) -> &'static str {
 34        self.name()
 35    }
 36
 37    fn empty_state_headline(&self) -> &'static str {
 38        self.empty_state_headline()
 39    }
 40
 41    fn empty_state_message(&self) -> &'static str {
 42        self.empty_state_message()
 43    }
 44
 45    fn logo(&self) -> ui::IconName {
 46        self.logo()
 47    }
 48
 49    fn supports_always_allow(&self) -> bool {
 50        self.supports_always_allow()
 51    }
 52
 53    fn new_thread(
 54        &self,
 55        root_dir: &Path,
 56        project: &Entity<Project>,
 57        cx: &mut App,
 58    ) -> Task<Result<Entity<AcpThread>>> {
 59        let root_dir = root_dir.to_path_buf();
 60        let project = project.clone();
 61        let this = self.clone();
 62        let title = self.name().into();
 63
 64        cx.spawn(async move |cx| {
 65            let command = this.command(&project, cx).await?;
 66
 67            let mut child = util::command::new_smol_command(&command.path)
 68                .args(command.args.iter())
 69                .current_dir(root_dir)
 70                .stdin(std::process::Stdio::piped())
 71                .stdout(std::process::Stdio::piped())
 72                .stderr(std::process::Stdio::inherit())
 73                .kill_on_drop(true)
 74                .spawn()?;
 75
 76            let stdin = child.stdin.take().unwrap();
 77            let stdout = child.stdout.take().unwrap();
 78
 79            cx.new(|cx| {
 80                let foreground_executor = cx.foreground_executor().clone();
 81
 82                let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
 83                    AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
 84                    stdin,
 85                    stdout,
 86                    move |fut| foreground_executor.spawn(fut).detach(),
 87                );
 88
 89                let io_task = cx.background_spawn(async move {
 90                    io_fut.await.log_err();
 91                });
 92
 93                let child_status = cx.background_spawn(async move {
 94                    let result = match child.status().await {
 95                        Err(e) => Err(anyhow!(e)),
 96                        Ok(result) if result.success() => Ok(()),
 97                        Ok(result) => {
 98                            if let Some(AgentServerVersion::Unsupported {
 99                                error_message,
100                                upgrade_message,
101                                upgrade_command,
102                            }) = this.version(&command).await.log_err()
103                            {
104                                Err(anyhow!(LoadError::Unsupported {
105                                    error_message,
106                                    upgrade_message,
107                                    upgrade_command
108                                }))
109                            } else {
110                                Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
111                            }
112                        }
113                    };
114                    drop(io_task);
115                    result
116                });
117
118                AcpThread::new(connection, title, Some(child_status), project.clone(), cx)
119            })
120        })
121    }
122}
123
124pub async fn find_bin_in_path(
125    bin_name: &'static str,
126    project: &Entity<Project>,
127    cx: &mut AsyncApp,
128) -> Option<PathBuf> {
129    let (env_task, root_dir) = project
130        .update(cx, |project, cx| {
131            let worktree = project.visible_worktrees(cx).next();
132            match worktree {
133                Some(worktree) => {
134                    let env_task = project.environment().update(cx, |env, cx| {
135                        env.get_worktree_environment(worktree.clone(), cx)
136                    });
137
138                    let path = worktree.read(cx).abs_path();
139                    (env_task, path)
140                }
141                None => {
142                    let path: Arc<Path> = paths::home_dir().as_path().into();
143                    let env_task = project.environment().update(cx, |env, cx| {
144                        env.get_directory_environment(path.clone(), cx)
145                    });
146                    (env_task, path)
147                }
148            }
149        })
150        .log_err()?;
151
152    cx.background_executor()
153        .spawn(async move {
154            let which_result = if cfg!(windows) {
155                which::which(bin_name)
156            } else {
157                let env = env_task.await.unwrap_or_default();
158                let shell_path = env.get("PATH").cloned();
159                which::which_in(bin_name, shell_path.as_ref(), root_dir.as_ref())
160            };
161
162            if let Err(which::Error::CannotFindBinaryPath) = which_result {
163                return None;
164            }
165
166            which_result.log_err()
167        })
168        .await
169}