1use agent_client_protocol::{self as acp, Agent as _};
2use anyhow::anyhow;
3use collections::HashMap;
4use futures::channel::oneshot;
5use project::Project;
6use std::cell::RefCell;
7use std::path::Path;
8use std::rc::Rc;
9
10use anyhow::{Context as _, Result};
11use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
12
13use crate::{AgentServerCommand, acp::UnsupportedVersion};
14use acp_thread::{AcpThread, AgentConnection, AuthRequired};
15
16pub struct AcpConnection {
17 server_name: &'static str,
18 connection: Rc<acp::ClientSideConnection>,
19 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
20 auth_methods: Vec<acp::AuthMethod>,
21 _io_task: Task<Result<()>>,
22}
23
24pub struct AcpSession {
25 thread: WeakEntity<AcpThread>,
26}
27
28const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
29
30impl AcpConnection {
31 pub async fn stdio(
32 server_name: &'static str,
33 command: AgentServerCommand,
34 root_dir: &Path,
35 cx: &mut AsyncApp,
36 ) -> Result<Self> {
37 let mut child = util::command::new_smol_command(&command.path)
38 .args(command.args.iter().map(|arg| arg.as_str()))
39 .envs(command.env.iter().flatten())
40 .current_dir(root_dir)
41 .stdin(std::process::Stdio::piped())
42 .stdout(std::process::Stdio::piped())
43 .stderr(std::process::Stdio::inherit())
44 .kill_on_drop(true)
45 .spawn()?;
46
47 let stdout = child.stdout.take().expect("Failed to take stdout");
48 let stdin = child.stdin.take().expect("Failed to take stdin");
49 log::trace!("Spawned (pid: {})", child.id());
50
51 let sessions = Rc::new(RefCell::new(HashMap::default()));
52
53 let client = ClientDelegate {
54 sessions: sessions.clone(),
55 cx: cx.clone(),
56 };
57 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
58 let foreground_executor = cx.foreground_executor().clone();
59 move |fut| {
60 foreground_executor.spawn(fut).detach();
61 }
62 });
63
64 let io_task = cx.background_spawn(io_task);
65
66 cx.spawn({
67 let sessions = sessions.clone();
68 async move |cx| {
69 let status = child.status().await?;
70
71 for session in sessions.borrow().values() {
72 session
73 .thread
74 .update(cx, |thread, cx| thread.emit_server_exited(status, cx))
75 .ok();
76 }
77
78 anyhow::Ok(())
79 }
80 })
81 .detach();
82
83 let response = connection
84 .initialize(acp::InitializeRequest {
85 protocol_version: acp::VERSION,
86 client_capabilities: acp::ClientCapabilities {
87 fs: acp::FileSystemCapability {
88 read_text_file: true,
89 write_text_file: true,
90 },
91 },
92 })
93 .await?;
94
95 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
96 return Err(UnsupportedVersion.into());
97 }
98
99 Ok(Self {
100 auth_methods: response.auth_methods,
101 connection: connection.into(),
102 server_name,
103 sessions,
104 _io_task: io_task,
105 })
106 }
107}
108
109impl AgentConnection for AcpConnection {
110 fn new_thread(
111 self: Rc<Self>,
112 project: Entity<Project>,
113 cwd: &Path,
114 cx: &mut App,
115 ) -> Task<Result<Entity<AcpThread>>> {
116 let conn = self.connection.clone();
117 let sessions = self.sessions.clone();
118 let cwd = cwd.to_path_buf();
119 cx.spawn(async move |cx| {
120 let response = conn
121 .new_session(acp::NewSessionRequest {
122 mcp_servers: vec![],
123 cwd,
124 })
125 .await
126 .map_err(|err| {
127 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
128 anyhow!(AuthRequired)
129 } else {
130 anyhow!(err)
131 }
132 })?;
133
134 let session_id = response.session_id;
135
136 let thread = cx.new(|cx| {
137 AcpThread::new(
138 self.server_name,
139 self.clone(),
140 project,
141 session_id.clone(),
142 cx,
143 )
144 })?;
145
146 let session = AcpSession {
147 thread: thread.downgrade(),
148 };
149 sessions.borrow_mut().insert(session_id, session);
150
151 Ok(thread)
152 })
153 }
154
155 fn auth_methods(&self) -> &[acp::AuthMethod] {
156 &self.auth_methods
157 }
158
159 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
160 let conn = self.connection.clone();
161 cx.foreground_executor().spawn(async move {
162 let result = conn
163 .authenticate(acp::AuthenticateRequest {
164 method_id: method_id.clone(),
165 })
166 .await?;
167
168 Ok(result)
169 })
170 }
171
172 fn prompt(
173 &self,
174 _id: Option<acp_thread::UserMessageId>,
175 params: acp::PromptRequest,
176 cx: &mut App,
177 ) -> Task<Result<acp::PromptResponse>> {
178 let conn = self.connection.clone();
179 cx.foreground_executor().spawn(async move {
180 let response = conn.prompt(params).await?;
181 Ok(response)
182 })
183 }
184
185 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
186 let conn = self.connection.clone();
187 let params = acp::CancelNotification {
188 session_id: session_id.clone(),
189 };
190 cx.foreground_executor()
191 .spawn(async move { conn.cancel(params).await })
192 .detach();
193 }
194}
195
/// Handles requests and notifications sent by the agent server, routing each one to the
/// [`AcpThread`] registered for the message's session id.
struct ClientDelegate {
    /// Session map shared with the owning `AcpConnection` (a clone of its `sessions`).
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Async app handle used to update the thread entities.
    cx: AsyncApp,
}
200
201impl acp::Client for ClientDelegate {
202 async fn request_permission(
203 &self,
204 arguments: acp::RequestPermissionRequest,
205 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
206 let cx = &mut self.cx.clone();
207 let rx = self
208 .sessions
209 .borrow()
210 .get(&arguments.session_id)
211 .context("Failed to get session")?
212 .thread
213 .update(cx, |thread, cx| {
214 thread.request_tool_call_authorization(arguments.tool_call, arguments.options, cx)
215 })?;
216
217 let result = rx.await;
218
219 let outcome = match result {
220 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
221 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
222 };
223
224 Ok(acp::RequestPermissionResponse { outcome })
225 }
226
227 async fn write_text_file(
228 &self,
229 arguments: acp::WriteTextFileRequest,
230 ) -> Result<(), acp::Error> {
231 let cx = &mut self.cx.clone();
232 let task = self
233 .sessions
234 .borrow()
235 .get(&arguments.session_id)
236 .context("Failed to get session")?
237 .thread
238 .update(cx, |thread, cx| {
239 thread.write_text_file(arguments.path, arguments.content, cx)
240 })?;
241
242 task.await?;
243
244 Ok(())
245 }
246
247 async fn read_text_file(
248 &self,
249 arguments: acp::ReadTextFileRequest,
250 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
251 let cx = &mut self.cx.clone();
252 let task = self
253 .sessions
254 .borrow()
255 .get(&arguments.session_id)
256 .context("Failed to get session")?
257 .thread
258 .update(cx, |thread, cx| {
259 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
260 })?;
261
262 let content = task.await?;
263
264 Ok(acp::ReadTextFileResponse { content })
265 }
266
267 async fn session_notification(
268 &self,
269 notification: acp::SessionNotification,
270 ) -> Result<(), acp::Error> {
271 let cx = &mut self.cx.clone();
272 let sessions = self.sessions.borrow();
273 let session = sessions
274 .get(¬ification.session_id)
275 .context("Failed to get session")?;
276
277 session.thread.update(cx, |thread, cx| {
278 thread.handle_session_update(notification.update, cx)
279 })??;
280
281 Ok(())
282 }
283}