1use agent_client_protocol::{self as acp, Agent as _};
2use anyhow::anyhow;
3use collections::HashMap;
4use futures::channel::oneshot;
5use project::Project;
6use std::cell::RefCell;
7use std::path::Path;
8use std::rc::Rc;
9
10use anyhow::{Context as _, Result};
11use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
12
13use crate::{AgentServerCommand, acp::UnsupportedVersion};
14use acp_thread::{AcpThread, AgentConnection, AuthRequired};
15
/// A connection to an agent server speaking the Agent Client Protocol (ACP).
pub struct AcpConnection {
    /// Name of the server; passed to every [`AcpThread`] this connection creates.
    server_name: &'static str,
    /// Client side of the ACP connection; all requests to the server go through it.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions opened via `new_thread`, keyed by session id. Shared with the
    /// [`ClientDelegate`] so that server callbacks can find their thread.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the server in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Task driving the connection's I/O (and, when created via `stdio`, owning
    /// the server's child process). Held so it keeps running while the
    /// connection exists.
    _io_task: Task<Result<()>>,
}
23
/// Per-session state tracked by an [`AcpConnection`].
pub struct AcpSession {
    /// The thread bound to this session. Held weakly so the session map does not
    /// keep the thread entity alive.
    thread: WeakEntity<AcpThread>,
}
27
/// Oldest ACP protocol version this client accepts. `AcpConnection::stdio` rejects
/// servers whose `initialize` response reports an older version with
/// [`UnsupportedVersion`].
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
29
30impl AcpConnection {
31 pub async fn stdio(
32 server_name: &'static str,
33 command: AgentServerCommand,
34 root_dir: &Path,
35 cx: &mut AsyncApp,
36 ) -> Result<Self> {
37 let mut child = util::command::new_smol_command(&command.path)
38 .args(command.args.iter().map(|arg| arg.as_str()))
39 .envs(command.env.iter().flatten())
40 .current_dir(root_dir)
41 .stdin(std::process::Stdio::piped())
42 .stdout(std::process::Stdio::piped())
43 .stderr(std::process::Stdio::inherit())
44 .kill_on_drop(true)
45 .spawn()?;
46
47 let stdout = child.stdout.take().expect("Failed to take stdout");
48 let stdin = child.stdin.take().expect("Failed to take stdin");
49 log::trace!("Spawned (pid: {})", child.id());
50
51 let sessions = Rc::new(RefCell::new(HashMap::default()));
52
53 let client = ClientDelegate {
54 sessions: sessions.clone(),
55 cx: cx.clone(),
56 };
57 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
58 let foreground_executor = cx.foreground_executor().clone();
59 move |fut| {
60 foreground_executor.spawn(fut).detach();
61 }
62 });
63
64 let io_task = cx.background_spawn(async move {
65 io_task.await?;
66 drop(child);
67 Ok(())
68 });
69
70 let response = connection
71 .initialize(acp::InitializeRequest {
72 protocol_version: acp::VERSION,
73 client_capabilities: acp::ClientCapabilities {
74 fs: acp::FileSystemCapability {
75 read_text_file: true,
76 write_text_file: true,
77 },
78 },
79 })
80 .await?;
81
82 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
83 return Err(UnsupportedVersion.into());
84 }
85
86 Ok(Self {
87 auth_methods: response.auth_methods,
88 connection: connection.into(),
89 server_name,
90 sessions,
91 _io_task: io_task,
92 })
93 }
94}
95
96impl AgentConnection for AcpConnection {
97 fn new_thread(
98 self: Rc<Self>,
99 project: Entity<Project>,
100 cwd: &Path,
101 cx: &mut AsyncApp,
102 ) -> Task<Result<Entity<AcpThread>>> {
103 let conn = self.connection.clone();
104 let sessions = self.sessions.clone();
105 let cwd = cwd.to_path_buf();
106 cx.spawn(async move |cx| {
107 let response = conn
108 .new_session(acp::NewSessionRequest {
109 mcp_servers: vec![],
110 cwd,
111 })
112 .await
113 .map_err(|err| {
114 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
115 anyhow!(AuthRequired)
116 } else {
117 anyhow!(err)
118 }
119 })?;
120
121 let session_id = response.session_id;
122
123 let thread = cx.new(|cx| {
124 AcpThread::new(
125 self.server_name,
126 self.clone(),
127 project,
128 session_id.clone(),
129 cx,
130 )
131 })?;
132
133 let session = AcpSession {
134 thread: thread.downgrade(),
135 };
136 sessions.borrow_mut().insert(session_id, session);
137
138 Ok(thread)
139 })
140 }
141
142 fn auth_methods(&self) -> &[acp::AuthMethod] {
143 &self.auth_methods
144 }
145
146 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
147 let conn = self.connection.clone();
148 cx.foreground_executor().spawn(async move {
149 let result = conn
150 .authenticate(acp::AuthenticateRequest {
151 method_id: method_id.clone(),
152 })
153 .await?;
154
155 Ok(result)
156 })
157 }
158
159 fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
160 let conn = self.connection.clone();
161 cx.foreground_executor().spawn(async move {
162 conn.prompt(params).await?;
163 Ok(())
164 })
165 }
166
167 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
168 let conn = self.connection.clone();
169 let params = acp::CancelNotification {
170 session_id: session_id.clone(),
171 };
172 cx.foreground_executor()
173 .spawn(async move { conn.cancel(params).await })
174 .detach();
175 }
176}
177
/// Handles requests and notifications sent by the agent server, routing each to
/// the [`AcpThread`] registered for its session.
struct ClientDelegate {
    /// Session map shared with the owning [`AcpConnection`], which inserts new
    /// sessions in `new_thread`.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App context used to update thread entities from the delegate callbacks.
    cx: AsyncApp,
}
182
183impl acp::Client for ClientDelegate {
184 async fn request_permission(
185 &self,
186 arguments: acp::RequestPermissionRequest,
187 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
188 let cx = &mut self.cx.clone();
189 let rx = self
190 .sessions
191 .borrow()
192 .get(&arguments.session_id)
193 .context("Failed to get session")?
194 .thread
195 .update(cx, |thread, cx| {
196 thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
197 })?;
198
199 let result = rx.await;
200
201 let outcome = match result {
202 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
203 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
204 };
205
206 Ok(acp::RequestPermissionResponse { outcome })
207 }
208
209 async fn write_text_file(
210 &self,
211 arguments: acp::WriteTextFileRequest,
212 ) -> Result<(), acp::Error> {
213 let cx = &mut self.cx.clone();
214 let task = self
215 .sessions
216 .borrow()
217 .get(&arguments.session_id)
218 .context("Failed to get session")?
219 .thread
220 .update(cx, |thread, cx| {
221 thread.write_text_file(arguments.path, arguments.content, cx)
222 })?;
223
224 task.await?;
225
226 Ok(())
227 }
228
229 async fn read_text_file(
230 &self,
231 arguments: acp::ReadTextFileRequest,
232 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
233 let cx = &mut self.cx.clone();
234 let task = self
235 .sessions
236 .borrow()
237 .get(&arguments.session_id)
238 .context("Failed to get session")?
239 .thread
240 .update(cx, |thread, cx| {
241 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
242 })?;
243
244 let content = task.await?;
245
246 Ok(acp::ReadTextFileResponse { content })
247 }
248
249 async fn session_notification(
250 &self,
251 notification: acp::SessionNotification,
252 ) -> Result<(), acp::Error> {
253 let cx = &mut self.cx.clone();
254 let sessions = self.sessions.borrow();
255 let session = sessions
256 .get(¬ification.session_id)
257 .context("Failed to get session")?;
258
259 session.thread.update(cx, |thread, cx| {
260 thread.handle_session_update(notification.update, cx)
261 })??;
262
263 Ok(())
264 }
265}