1use agent_client_protocol::{self as acp, Agent as _};
2use anyhow::anyhow;
3use collections::HashMap;
4use futures::AsyncBufReadExt as _;
5use futures::channel::oneshot;
6use futures::io::BufReader;
7use project::Project;
8use std::path::Path;
9use std::rc::Rc;
10use std::{any::Any, cell::RefCell};
11
12use anyhow::{Context as _, Result};
13use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
14
15use crate::{AgentServerCommand, acp::UnsupportedVersion};
16use acp_thread::{AcpThread, AgentConnection, AuthRequired};
17
18pub struct AcpConnection {
19 server_name: &'static str,
20 connection: Rc<acp::ClientSideConnection>,
21 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
22 auth_methods: Vec<acp::AuthMethod>,
23 _io_task: Task<Result<()>>,
24}
25
26pub struct AcpSession {
27 thread: WeakEntity<AcpThread>,
28}
29
30const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
31
32impl AcpConnection {
33 pub async fn stdio(
34 server_name: &'static str,
35 command: AgentServerCommand,
36 root_dir: &Path,
37 cx: &mut AsyncApp,
38 ) -> Result<Self> {
39 let mut child = util::command::new_smol_command(&command.path)
40 .args(command.args.iter().map(|arg| arg.as_str()))
41 .envs(command.env.iter().flatten())
42 .current_dir(root_dir)
43 .stdin(std::process::Stdio::piped())
44 .stdout(std::process::Stdio::piped())
45 .stderr(std::process::Stdio::piped())
46 .kill_on_drop(true)
47 .spawn()?;
48
49 let stdout = child.stdout.take().context("Failed to take stdout")?;
50 let stdin = child.stdin.take().context("Failed to take stdin")?;
51 let stderr = child.stderr.take().context("Failed to take stderr")?;
52 log::trace!("Spawned (pid: {})", child.id());
53
54 let sessions = Rc::new(RefCell::new(HashMap::default()));
55
56 let client = ClientDelegate {
57 sessions: sessions.clone(),
58 cx: cx.clone(),
59 };
60 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
61 let foreground_executor = cx.foreground_executor().clone();
62 move |fut| {
63 foreground_executor.spawn(fut).detach();
64 }
65 });
66
67 let io_task = cx.background_spawn(io_task);
68
69 cx.background_spawn(async move {
70 let mut stderr = BufReader::new(stderr);
71 let mut line = String::new();
72 while let Ok(n) = stderr.read_line(&mut line).await
73 && n > 0
74 {
75 log::warn!("agent stderr: {}", &line);
76 line.clear();
77 }
78 })
79 .detach();
80
81 cx.spawn({
82 let sessions = sessions.clone();
83 async move |cx| {
84 let status = child.status().await?;
85
86 for session in sessions.borrow().values() {
87 session
88 .thread
89 .update(cx, |thread, cx| thread.emit_server_exited(status, cx))
90 .ok();
91 }
92
93 anyhow::Ok(())
94 }
95 })
96 .detach();
97
98 let response = connection
99 .initialize(acp::InitializeRequest {
100 protocol_version: acp::VERSION,
101 client_capabilities: acp::ClientCapabilities {
102 fs: acp::FileSystemCapability {
103 read_text_file: true,
104 write_text_file: true,
105 },
106 },
107 })
108 .await?;
109
110 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
111 return Err(UnsupportedVersion.into());
112 }
113
114 Ok(Self {
115 auth_methods: response.auth_methods,
116 connection: connection.into(),
117 server_name,
118 sessions,
119 _io_task: io_task,
120 })
121 }
122}
123
124impl AgentConnection for AcpConnection {
125 fn new_thread(
126 self: Rc<Self>,
127 project: Entity<Project>,
128 cwd: &Path,
129 cx: &mut App,
130 ) -> Task<Result<Entity<AcpThread>>> {
131 let conn = self.connection.clone();
132 let sessions = self.sessions.clone();
133 let cwd = cwd.to_path_buf();
134 cx.spawn(async move |cx| {
135 let response = conn
136 .new_session(acp::NewSessionRequest {
137 mcp_servers: vec![],
138 cwd,
139 })
140 .await
141 .map_err(|err| {
142 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
143 let mut error = AuthRequired::new();
144
145 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
146 error = error.with_description(err.message);
147 }
148
149 anyhow!(error)
150 } else {
151 anyhow!(err)
152 }
153 })?;
154
155 let session_id = response.session_id;
156
157 let thread = cx.new(|cx| {
158 AcpThread::new(
159 self.server_name,
160 self.clone(),
161 project,
162 session_id.clone(),
163 cx,
164 )
165 })?;
166
167 let session = AcpSession {
168 thread: thread.downgrade(),
169 };
170 sessions.borrow_mut().insert(session_id, session);
171
172 Ok(thread)
173 })
174 }
175
176 fn auth_methods(&self) -> &[acp::AuthMethod] {
177 &self.auth_methods
178 }
179
180 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
181 let conn = self.connection.clone();
182 cx.foreground_executor().spawn(async move {
183 let result = conn
184 .authenticate(acp::AuthenticateRequest {
185 method_id: method_id.clone(),
186 })
187 .await?;
188
189 Ok(result)
190 })
191 }
192
193 fn prompt(
194 &self,
195 _id: Option<acp_thread::UserMessageId>,
196 params: acp::PromptRequest,
197 cx: &mut App,
198 ) -> Task<Result<acp::PromptResponse>> {
199 let conn = self.connection.clone();
200 cx.foreground_executor().spawn(async move {
201 let response = conn.prompt(params).await?;
202 Ok(response)
203 })
204 }
205
206 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
207 let conn = self.connection.clone();
208 let params = acp::CancelNotification {
209 session_id: session_id.clone(),
210 };
211 cx.foreground_executor()
212 .spawn(async move { conn.cancel(params).await })
213 .detach();
214 }
215
216 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
217 self
218 }
219}
220
221struct ClientDelegate {
222 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
223 cx: AsyncApp,
224}
225
226impl acp::Client for ClientDelegate {
227 async fn request_permission(
228 &self,
229 arguments: acp::RequestPermissionRequest,
230 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
231 let cx = &mut self.cx.clone();
232 let rx = self
233 .sessions
234 .borrow()
235 .get(&arguments.session_id)
236 .context("Failed to get session")?
237 .thread
238 .update(cx, |thread, cx| {
239 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
240 })?;
241
242 let result = rx?.await;
243
244 let outcome = match result {
245 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
246 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
247 };
248
249 Ok(acp::RequestPermissionResponse { outcome })
250 }
251
252 async fn write_text_file(
253 &self,
254 arguments: acp::WriteTextFileRequest,
255 ) -> Result<(), acp::Error> {
256 let cx = &mut self.cx.clone();
257 let task = self
258 .sessions
259 .borrow()
260 .get(&arguments.session_id)
261 .context("Failed to get session")?
262 .thread
263 .update(cx, |thread, cx| {
264 thread.write_text_file(arguments.path, arguments.content, cx)
265 })?;
266
267 task.await?;
268
269 Ok(())
270 }
271
272 async fn read_text_file(
273 &self,
274 arguments: acp::ReadTextFileRequest,
275 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
276 let cx = &mut self.cx.clone();
277 let task = self
278 .sessions
279 .borrow()
280 .get(&arguments.session_id)
281 .context("Failed to get session")?
282 .thread
283 .update(cx, |thread, cx| {
284 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
285 })?;
286
287 let content = task.await?;
288
289 Ok(acp::ReadTextFileResponse { content })
290 }
291
292 async fn session_notification(
293 &self,
294 notification: acp::SessionNotification,
295 ) -> Result<(), acp::Error> {
296 let cx = &mut self.cx.clone();
297 let sessions = self.sessions.borrow();
298 let session = sessions
299 .get(¬ification.session_id)
300 .context("Failed to get session")?;
301
302 session.thread.update(cx, |thread, cx| {
303 thread.handle_session_update(notification.update, cx)
304 })??;
305
306 Ok(())
307 }
308}