1use action_log::ActionLog;
2use agent_client_protocol::{self as acp, Agent as _};
3use anyhow::anyhow;
4use collections::HashMap;
5use futures::AsyncBufReadExt as _;
6use futures::channel::oneshot;
7use futures::io::BufReader;
8use project::Project;
9use std::path::Path;
10use std::rc::Rc;
11use std::{any::Any, cell::RefCell};
12
13use anyhow::{Context as _, Result};
14use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
15
16use crate::{AgentServerCommand, acp::UnsupportedVersion};
17use acp_thread::{AcpThread, AgentConnection, AuthRequired};
18
19pub struct AcpConnection {
20 server_name: &'static str,
21 connection: Rc<acp::ClientSideConnection>,
22 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
23 auth_methods: Vec<acp::AuthMethod>,
24 _io_task: Task<Result<()>>,
25}
26
27pub struct AcpSession {
28 thread: WeakEntity<AcpThread>,
29}
30
31const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
32
33impl AcpConnection {
34 pub async fn stdio(
35 server_name: &'static str,
36 command: AgentServerCommand,
37 root_dir: &Path,
38 cx: &mut AsyncApp,
39 ) -> Result<Self> {
40 let mut child = util::command::new_smol_command(&command.path)
41 .args(command.args.iter().map(|arg| arg.as_str()))
42 .envs(command.env.iter().flatten())
43 .current_dir(root_dir)
44 .stdin(std::process::Stdio::piped())
45 .stdout(std::process::Stdio::piped())
46 .stderr(std::process::Stdio::piped())
47 .kill_on_drop(true)
48 .spawn()?;
49
50 let stdout = child.stdout.take().context("Failed to take stdout")?;
51 let stdin = child.stdin.take().context("Failed to take stdin")?;
52 let stderr = child.stderr.take().context("Failed to take stderr")?;
53 log::trace!("Spawned (pid: {})", child.id());
54
55 let sessions = Rc::new(RefCell::new(HashMap::default()));
56
57 let client = ClientDelegate {
58 sessions: sessions.clone(),
59 cx: cx.clone(),
60 };
61 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
62 let foreground_executor = cx.foreground_executor().clone();
63 move |fut| {
64 foreground_executor.spawn(fut).detach();
65 }
66 });
67
68 let io_task = cx.background_spawn(io_task);
69
70 cx.background_spawn(async move {
71 let mut stderr = BufReader::new(stderr);
72 let mut line = String::new();
73 while let Ok(n) = stderr.read_line(&mut line).await
74 && n > 0
75 {
76 log::warn!("agent stderr: {}", &line);
77 line.clear();
78 }
79 })
80 .detach();
81
82 cx.spawn({
83 let sessions = sessions.clone();
84 async move |cx| {
85 let status = child.status().await?;
86
87 for session in sessions.borrow().values() {
88 session
89 .thread
90 .update(cx, |thread, cx| thread.emit_server_exited(status, cx))
91 .ok();
92 }
93
94 anyhow::Ok(())
95 }
96 })
97 .detach();
98
99 let response = connection
100 .initialize(acp::InitializeRequest {
101 protocol_version: acp::VERSION,
102 client_capabilities: acp::ClientCapabilities {
103 fs: acp::FileSystemCapability {
104 read_text_file: true,
105 write_text_file: true,
106 },
107 },
108 })
109 .await?;
110
111 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
112 return Err(UnsupportedVersion.into());
113 }
114
115 Ok(Self {
116 auth_methods: response.auth_methods,
117 connection: connection.into(),
118 server_name,
119 sessions,
120 _io_task: io_task,
121 })
122 }
123}
124
125impl AgentConnection for AcpConnection {
126 fn new_thread(
127 self: Rc<Self>,
128 project: Entity<Project>,
129 cwd: &Path,
130 cx: &mut App,
131 ) -> Task<Result<Entity<AcpThread>>> {
132 let conn = self.connection.clone();
133 let sessions = self.sessions.clone();
134 let cwd = cwd.to_path_buf();
135 cx.spawn(async move |cx| {
136 let response = conn
137 .new_session(acp::NewSessionRequest {
138 mcp_servers: vec![],
139 cwd,
140 })
141 .await
142 .map_err(|err| {
143 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
144 let mut error = AuthRequired::new();
145
146 if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
147 error = error.with_description(err.message);
148 }
149
150 anyhow!(error)
151 } else {
152 anyhow!(err)
153 }
154 })?;
155
156 let session_id = response.session_id;
157 let action_log = cx.new(|_| ActionLog::new(project.clone()))?;
158 let thread = cx.new(|_cx| {
159 AcpThread::new(
160 self.server_name,
161 self.clone(),
162 project,
163 action_log,
164 session_id.clone(),
165 )
166 })?;
167
168 let session = AcpSession {
169 thread: thread.downgrade(),
170 };
171 sessions.borrow_mut().insert(session_id, session);
172
173 Ok(thread)
174 })
175 }
176
177 fn auth_methods(&self) -> &[acp::AuthMethod] {
178 &self.auth_methods
179 }
180
181 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
182 let conn = self.connection.clone();
183 cx.foreground_executor().spawn(async move {
184 let result = conn
185 .authenticate(acp::AuthenticateRequest {
186 method_id: method_id.clone(),
187 })
188 .await?;
189
190 Ok(result)
191 })
192 }
193
194 fn prompt(
195 &self,
196 _id: Option<acp_thread::UserMessageId>,
197 params: acp::PromptRequest,
198 cx: &mut App,
199 ) -> Task<Result<acp::PromptResponse>> {
200 let conn = self.connection.clone();
201 cx.foreground_executor().spawn(async move {
202 let response = conn.prompt(params).await?;
203 Ok(response)
204 })
205 }
206
207 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
208 let conn = self.connection.clone();
209 let params = acp::CancelNotification {
210 session_id: session_id.clone(),
211 };
212 cx.foreground_executor()
213 .spawn(async move { conn.cancel(params).await })
214 .detach();
215 }
216
217 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
218 self
219 }
220}
221
222struct ClientDelegate {
223 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
224 cx: AsyncApp,
225}
226
227impl acp::Client for ClientDelegate {
228 async fn request_permission(
229 &self,
230 arguments: acp::RequestPermissionRequest,
231 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
232 let cx = &mut self.cx.clone();
233 let rx = self
234 .sessions
235 .borrow()
236 .get(&arguments.session_id)
237 .context("Failed to get session")?
238 .thread
239 .update(cx, |thread, cx| {
240 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
241 })?;
242
243 let result = rx?.await;
244
245 let outcome = match result {
246 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
247 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
248 };
249
250 Ok(acp::RequestPermissionResponse { outcome })
251 }
252
253 async fn write_text_file(
254 &self,
255 arguments: acp::WriteTextFileRequest,
256 ) -> Result<(), acp::Error> {
257 let cx = &mut self.cx.clone();
258 let task = self
259 .sessions
260 .borrow()
261 .get(&arguments.session_id)
262 .context("Failed to get session")?
263 .thread
264 .update(cx, |thread, cx| {
265 thread.write_text_file(arguments.path, arguments.content, cx)
266 })?;
267
268 task.await?;
269
270 Ok(())
271 }
272
273 async fn read_text_file(
274 &self,
275 arguments: acp::ReadTextFileRequest,
276 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
277 let cx = &mut self.cx.clone();
278 let task = self
279 .sessions
280 .borrow()
281 .get(&arguments.session_id)
282 .context("Failed to get session")?
283 .thread
284 .update(cx, |thread, cx| {
285 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
286 })?;
287
288 let content = task.await?;
289
290 Ok(acp::ReadTextFileResponse { content })
291 }
292
293 async fn session_notification(
294 &self,
295 notification: acp::SessionNotification,
296 ) -> Result<(), acp::Error> {
297 let cx = &mut self.cx.clone();
298 let sessions = self.sessions.borrow();
299 let session = sessions
300 .get(¬ification.session_id)
301 .context("Failed to get session")?;
302
303 session.thread.update(cx, |thread, cx| {
304 thread.handle_session_update(notification.update, cx)
305 })??;
306
307 Ok(())
308 }
309}