1use agent_client_protocol::{self as acp, Agent as _};
2use anyhow::anyhow;
3use collections::HashMap;
4use futures::AsyncBufReadExt as _;
5use futures::channel::oneshot;
6use futures::io::BufReader;
7use project::Project;
8use std::path::Path;
9use std::rc::Rc;
10use std::{any::Any, cell::RefCell};
11
12use anyhow::{Context as _, Result};
13use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
14
15use crate::{AgentServerCommand, acp::UnsupportedVersion};
16use acp_thread::{AcpThread, AgentConnection, AuthRequired};
17
/// A connection to an agent server process, spoken to over the Agent Client Protocol (ACP).
pub struct AcpConnection {
    /// Name handed to every [`AcpThread`] created through this connection.
    server_name: &'static str,
    /// Client side of the ACP connection; used to send requests and notifications to the agent.
    connection: Rc<acp::ClientSideConnection>,
    /// Known sessions keyed by id. The same map is shared with the `ClientDelegate` that
    /// handles agent callbacks and with the task that watches for the process exiting.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the agent in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Task driving the connection's I/O. It is held here so that it lives as long as the
    /// connection (presumably dropping it stops the I/O — confirm against gpui `Task` semantics).
    _io_task: Task<Result<()>>,
}
25
/// Per-session state tracked by an [`AcpConnection`].
pub struct AcpSession {
    /// Weak handle to the thread created for this session. Requests and notifications the
    /// agent sends for the session are forwarded to it.
    thread: WeakEntity<AcpThread>,
}
29
/// Oldest protocol version accepted from the agent's `initialize` response.
/// Agents reporting an older version are rejected with [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
31
32impl AcpConnection {
33 pub async fn stdio(
34 server_name: &'static str,
35 command: AgentServerCommand,
36 root_dir: &Path,
37 cx: &mut AsyncApp,
38 ) -> Result<Self> {
39 let mut child = util::command::new_smol_command(&command.path)
40 .args(command.args.iter().map(|arg| arg.as_str()))
41 .envs(command.env.iter().flatten())
42 .current_dir(root_dir)
43 .stdin(std::process::Stdio::piped())
44 .stdout(std::process::Stdio::piped())
45 .stderr(std::process::Stdio::piped())
46 .kill_on_drop(true)
47 .spawn()?;
48
49 let stdout = child.stdout.take().context("Failed to take stdout")?;
50 let stdin = child.stdin.take().context("Failed to take stdin")?;
51 let stderr = child.stderr.take().context("Failed to take stderr")?;
52 log::trace!("Spawned (pid: {})", child.id());
53
54 let sessions = Rc::new(RefCell::new(HashMap::default()));
55
56 let client = ClientDelegate {
57 sessions: sessions.clone(),
58 cx: cx.clone(),
59 };
60 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
61 let foreground_executor = cx.foreground_executor().clone();
62 move |fut| {
63 foreground_executor.spawn(fut).detach();
64 }
65 });
66
67 let io_task = cx.background_spawn(io_task);
68
69 cx.background_spawn(async move {
70 let mut stderr = BufReader::new(stderr);
71 let mut line = String::new();
72 while let Ok(n) = stderr.read_line(&mut line).await
73 && n > 0
74 {
75 log::warn!("agent stderr: {}", &line);
76 line.clear();
77 }
78 })
79 .detach();
80
81 cx.spawn({
82 let sessions = sessions.clone();
83 async move |cx| {
84 let status = child.status().await?;
85
86 for session in sessions.borrow().values() {
87 session
88 .thread
89 .update(cx, |thread, cx| thread.emit_server_exited(status, cx))
90 .ok();
91 }
92
93 anyhow::Ok(())
94 }
95 })
96 .detach();
97
98 let response = connection
99 .initialize(acp::InitializeRequest {
100 protocol_version: acp::VERSION,
101 client_capabilities: acp::ClientCapabilities {
102 fs: acp::FileSystemCapability {
103 read_text_file: true,
104 write_text_file: true,
105 },
106 },
107 })
108 .await?;
109
110 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
111 return Err(UnsupportedVersion.into());
112 }
113
114 Ok(Self {
115 auth_methods: response.auth_methods,
116 connection: connection.into(),
117 server_name,
118 sessions,
119 _io_task: io_task,
120 })
121 }
122}
123
124impl AgentConnection for AcpConnection {
125 fn new_thread(
126 self: Rc<Self>,
127 project: Entity<Project>,
128 cwd: &Path,
129 cx: &mut App,
130 ) -> Task<Result<Entity<AcpThread>>> {
131 let conn = self.connection.clone();
132 let sessions = self.sessions.clone();
133 let cwd = cwd.to_path_buf();
134 cx.spawn(async move |cx| {
135 let response = conn
136 .new_session(acp::NewSessionRequest {
137 mcp_servers: vec![],
138 cwd,
139 })
140 .await
141 .map_err(|err| {
142 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
143 anyhow!(AuthRequired)
144 } else {
145 anyhow!(err)
146 }
147 })?;
148
149 let session_id = response.session_id;
150
151 let thread = cx.new(|cx| {
152 AcpThread::new(
153 self.server_name,
154 self.clone(),
155 project,
156 session_id.clone(),
157 cx,
158 )
159 })?;
160
161 let session = AcpSession {
162 thread: thread.downgrade(),
163 };
164 sessions.borrow_mut().insert(session_id, session);
165
166 Ok(thread)
167 })
168 }
169
170 fn auth_methods(&self) -> &[acp::AuthMethod] {
171 &self.auth_methods
172 }
173
174 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
175 let conn = self.connection.clone();
176 cx.foreground_executor().spawn(async move {
177 let result = conn
178 .authenticate(acp::AuthenticateRequest {
179 method_id: method_id.clone(),
180 })
181 .await?;
182
183 Ok(result)
184 })
185 }
186
187 fn prompt(
188 &self,
189 _id: Option<acp_thread::UserMessageId>,
190 params: acp::PromptRequest,
191 cx: &mut App,
192 ) -> Task<Result<acp::PromptResponse>> {
193 let conn = self.connection.clone();
194 cx.foreground_executor().spawn(async move {
195 let response = conn.prompt(params).await?;
196 Ok(response)
197 })
198 }
199
200 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
201 let conn = self.connection.clone();
202 let params = acp::CancelNotification {
203 session_id: session_id.clone(),
204 };
205 cx.foreground_executor()
206 .spawn(async move { conn.cancel(params).await })
207 .detach();
208 }
209
210 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
211 self
212 }
213}
214
/// Handles requests and notifications sent by the agent. Each one is forwarded to the
/// [`AcpThread`] of the session it names.
struct ClientDelegate {
    /// Session map shared with the owning `AcpConnection`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App context; cloned in each callback to update the target thread.
    cx: AsyncApp,
}
219
220impl acp::Client for ClientDelegate {
221 async fn request_permission(
222 &self,
223 arguments: acp::RequestPermissionRequest,
224 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
225 let cx = &mut self.cx.clone();
226 let rx = self
227 .sessions
228 .borrow()
229 .get(&arguments.session_id)
230 .context("Failed to get session")?
231 .thread
232 .update(cx, |thread, cx| {
233 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
234 })?;
235
236 let result = rx?.await;
237
238 let outcome = match result {
239 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
240 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
241 };
242
243 Ok(acp::RequestPermissionResponse { outcome })
244 }
245
246 async fn write_text_file(
247 &self,
248 arguments: acp::WriteTextFileRequest,
249 ) -> Result<(), acp::Error> {
250 let cx = &mut self.cx.clone();
251 let task = self
252 .sessions
253 .borrow()
254 .get(&arguments.session_id)
255 .context("Failed to get session")?
256 .thread
257 .update(cx, |thread, cx| {
258 thread.write_text_file(arguments.path, arguments.content, cx)
259 })?;
260
261 task.await?;
262
263 Ok(())
264 }
265
266 async fn read_text_file(
267 &self,
268 arguments: acp::ReadTextFileRequest,
269 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
270 let cx = &mut self.cx.clone();
271 let task = self
272 .sessions
273 .borrow()
274 .get(&arguments.session_id)
275 .context("Failed to get session")?
276 .thread
277 .update(cx, |thread, cx| {
278 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
279 })?;
280
281 let content = task.await?;
282
283 Ok(acp::ReadTextFileResponse { content })
284 }
285
286 async fn session_notification(
287 &self,
288 notification: acp::SessionNotification,
289 ) -> Result<(), acp::Error> {
290 let cx = &mut self.cx.clone();
291 let sessions = self.sessions.borrow();
292 let session = sessions
293 .get(¬ification.session_id)
294 .context("Failed to get session")?;
295
296 session.thread.update(cx, |thread, cx| {
297 thread.handle_session_update(notification.update, cx)
298 })??;
299
300 Ok(())
301 }
302}