stdio_agent_server.rs

  1use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
  2use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
  3use agentic_coding_protocol as acp_old;
  4use anyhow::{Result, anyhow};
  5use gpui::{App, AsyncApp, Entity, Task, WeakEntity, prelude::*};
  6use project::Project;
  7use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc};
  8use util::ResultExt;
  9
 10pub trait StdioAgentServer: Send + Clone {
 11    fn logo(&self) -> ui::IconName;
 12    fn name(&self) -> &'static str;
 13    fn empty_state_headline(&self) -> &'static str;
 14    fn empty_state_message(&self) -> &'static str;
 15    fn supports_always_allow(&self) -> bool;
 16
 17    fn command(
 18        &self,
 19        project: &Entity<Project>,
 20        cx: &mut AsyncApp,
 21    ) -> impl Future<Output = Result<AgentServerCommand>>;
 22
 23    fn version(
 24        &self,
 25        command: &AgentServerCommand,
 26    ) -> impl Future<Output = Result<AgentServerVersion>> + Send;
 27}
 28
 29impl<T: StdioAgentServer + 'static> AgentServer for T {
 30    fn name(&self) -> &'static str {
 31        self.name()
 32    }
 33
 34    fn empty_state_headline(&self) -> &'static str {
 35        self.empty_state_headline()
 36    }
 37
 38    fn empty_state_message(&self) -> &'static str {
 39        self.empty_state_message()
 40    }
 41
 42    fn logo(&self) -> ui::IconName {
 43        self.logo()
 44    }
 45
 46    fn supports_always_allow(&self) -> bool {
 47        self.supports_always_allow()
 48    }
 49
 50    fn connect(
 51        &self,
 52        root_dir: &Path,
 53        project: &Entity<Project>,
 54        cx: &mut App,
 55    ) -> Task<Result<Arc<dyn AgentConnection>>> {
 56        let root_dir = root_dir.to_path_buf();
 57        let project = project.clone();
 58        let this = self.clone();
 59
 60        cx.spawn(async move |cx| {
 61            let command = this.command(&project, cx).await?;
 62
 63            let mut child = util::command::new_smol_command(&command.path)
 64                .args(command.args.iter())
 65                .current_dir(root_dir)
 66                .stdin(std::process::Stdio::piped())
 67                .stdout(std::process::Stdio::piped())
 68                .stderr(std::process::Stdio::inherit())
 69                .kill_on_drop(true)
 70                .spawn()?;
 71
 72            let stdin = child.stdin.take().unwrap();
 73            let stdout = child.stdout.take().unwrap();
 74
 75            let foreground_executor = cx.foreground_executor().clone();
 76
 77            let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
 78
 79            let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
 80                OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
 81                stdin,
 82                stdout,
 83                move |fut| foreground_executor.spawn(fut).detach(),
 84            );
 85
 86            let io_task = cx.background_spawn(async move {
 87                io_fut.await.log_err();
 88            });
 89
 90            let child_status = cx.background_spawn(async move {
 91                let result = match child.status().await {
 92                    Err(e) => Err(anyhow!(e)),
 93                    Ok(result) if result.success() => Ok(()),
 94                    Ok(result) => {
 95                        if let Some(AgentServerVersion::Unsupported {
 96                            error_message,
 97                            upgrade_message,
 98                            upgrade_command,
 99                        }) = this.version(&command).await.log_err()
100                        {
101                            Err(anyhow!(LoadError::Unsupported {
102                                error_message,
103                                upgrade_message,
104                                upgrade_command
105                            }))
106                        } else {
107                            Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
108                        }
109                    }
110                };
111                drop(io_task);
112                result
113            });
114
115            let connection: Arc<dyn AgentConnection> = Arc::new(OldAcpAgentConnection {
116                connection,
117                child_status,
118                thread: thread_rc,
119            });
120
121            Ok(connection)
122        })
123    }
124}