1use agent_client_protocol::{self as acp, Agent as _};
2use anyhow::anyhow;
3use collections::HashMap;
4use futures::channel::oneshot;
5use project::Project;
6use std::cell::RefCell;
7use std::path::Path;
8use std::rc::Rc;
9
10use anyhow::{Context as _, Result};
11use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
12
13use crate::{AgentServerCommand, acp::UnsupportedVersion};
14use acp_thread::{AcpThread, AgentConnection, AuthRequired};
15
/// An `AgentConnection` backed by an agent server process that speaks ACP over its stdio.
pub struct AcpConnection {
    /// Name of the agent server; passed to every `AcpThread` that `new_thread` creates.
    server_name: &'static str,
    /// Client side of the ACP connection; outgoing requests (initialize, new_session,
    /// authenticate, prompt, cancel) are sent through it.
    connection: Rc<acp::ClientSideConnection>,
    /// Threads keyed by session id. Shared with `ClientDelegate` so incoming requests can be
    /// routed to the right thread; `new_thread` registers each new session here.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods advertised by the server in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Background task, spawned in `stdio`, that drives the connection's I/O. Never read; it is
    /// stored so that it is owned by (and dropped with) this connection.
    _io_task: Task<Result<()>>,
    /// The agent server process, spawned with `kill_on_drop(true)`, so dropping the connection
    /// kills it. Fields drop in declaration order, so `_io_task` is dropped before this.
    _child: smol::process::Child,
}
24
/// Per-session state that `AcpConnection` tracks for each session it has opened.
pub struct AcpSession {
    /// Weak handle (created via `downgrade` in `new_thread`) to the thread for this session.
    thread: WeakEntity<AcpThread>,
}
28
/// Oldest ACP protocol version this client accepts. `AcpConnection::stdio` fails with
/// `UnsupportedVersion` when the server's `initialize` response reports an older version.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
30
31impl AcpConnection {
32 pub async fn stdio(
33 server_name: &'static str,
34 command: AgentServerCommand,
35 root_dir: &Path,
36 cx: &mut AsyncApp,
37 ) -> Result<Self> {
38 let mut child = util::command::new_smol_command(&command.path)
39 .args(command.args.iter().map(|arg| arg.as_str()))
40 .envs(command.env.iter().flatten())
41 .current_dir(root_dir)
42 .stdin(std::process::Stdio::piped())
43 .stdout(std::process::Stdio::piped())
44 .stderr(std::process::Stdio::inherit())
45 .kill_on_drop(true)
46 .spawn()?;
47
48 let stdout = child.stdout.take().expect("Failed to take stdout");
49 let stdin = child.stdin.take().expect("Failed to take stdin");
50
51 let sessions = Rc::new(RefCell::new(HashMap::default()));
52
53 let client = ClientDelegate {
54 sessions: sessions.clone(),
55 cx: cx.clone(),
56 };
57 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
58 let foreground_executor = cx.foreground_executor().clone();
59 move |fut| {
60 foreground_executor.spawn(fut).detach();
61 }
62 });
63
64 let io_task = cx.background_spawn(io_task);
65
66 let response = connection
67 .initialize(acp::InitializeRequest {
68 protocol_version: acp::VERSION,
69 client_capabilities: acp::ClientCapabilities {
70 fs: acp::FileSystemCapability {
71 read_text_file: true,
72 write_text_file: true,
73 },
74 },
75 })
76 .await?;
77
78 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
79 return Err(UnsupportedVersion.into());
80 }
81
82 Ok(Self {
83 auth_methods: response.auth_methods,
84 connection: connection.into(),
85 server_name,
86 sessions,
87 _child: child,
88 _io_task: io_task,
89 })
90 }
91}
92
93impl AgentConnection for AcpConnection {
94 fn new_thread(
95 self: Rc<Self>,
96 project: Entity<Project>,
97 cwd: &Path,
98 cx: &mut AsyncApp,
99 ) -> Task<Result<Entity<AcpThread>>> {
100 let conn = self.connection.clone();
101 let sessions = self.sessions.clone();
102 let cwd = cwd.to_path_buf();
103 cx.spawn(async move |cx| {
104 let response = conn
105 .new_session(acp::NewSessionRequest {
106 mcp_servers: vec![],
107 cwd,
108 })
109 .await
110 .map_err(|err| {
111 if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
112 anyhow!(AuthRequired)
113 } else {
114 anyhow!(err)
115 }
116 })?;
117
118 let session_id = response.session_id;
119
120 let thread = cx.new(|cx| {
121 AcpThread::new(
122 self.server_name,
123 self.clone(),
124 project,
125 session_id.clone(),
126 cx,
127 )
128 })?;
129
130 let session = AcpSession {
131 thread: thread.downgrade(),
132 };
133 sessions.borrow_mut().insert(session_id, session);
134
135 Ok(thread)
136 })
137 }
138
139 fn auth_methods(&self) -> &[acp::AuthMethod] {
140 &self.auth_methods
141 }
142
143 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
144 let conn = self.connection.clone();
145 cx.foreground_executor().spawn(async move {
146 let result = conn
147 .authenticate(acp::AuthenticateRequest {
148 method_id: method_id.clone(),
149 })
150 .await?;
151
152 Ok(result)
153 })
154 }
155
156 fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
157 let conn = self.connection.clone();
158 cx.foreground_executor()
159 .spawn(async move { Ok(conn.prompt(params).await?) })
160 }
161
162 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
163 let conn = self.connection.clone();
164 let params = acp::CancelNotification {
165 session_id: session_id.clone(),
166 };
167 cx.foreground_executor()
168 .spawn(async move { conn.cancel(params).await })
169 .detach();
170 }
171}
172
173struct ClientDelegate {
174 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
175 cx: AsyncApp,
176}
177
178impl acp::Client for ClientDelegate {
179 async fn request_permission(
180 &self,
181 arguments: acp::RequestPermissionRequest,
182 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
183 let cx = &mut self.cx.clone();
184 let rx = self
185 .sessions
186 .borrow()
187 .get(&arguments.session_id)
188 .context("Failed to get session")?
189 .thread
190 .update(cx, |thread, cx| {
191 thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
192 })?;
193
194 let result = rx.await;
195
196 let outcome = match result {
197 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
198 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
199 };
200
201 Ok(acp::RequestPermissionResponse { outcome })
202 }
203
204 async fn write_text_file(
205 &self,
206 arguments: acp::WriteTextFileRequest,
207 ) -> Result<(), acp::Error> {
208 let cx = &mut self.cx.clone();
209 let task = self
210 .sessions
211 .borrow()
212 .get(&arguments.session_id)
213 .context("Failed to get session")?
214 .thread
215 .update(cx, |thread, cx| {
216 thread.write_text_file(arguments.path, arguments.content, cx)
217 })?;
218
219 task.await?;
220
221 Ok(())
222 }
223
224 async fn read_text_file(
225 &self,
226 arguments: acp::ReadTextFileRequest,
227 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
228 let cx = &mut self.cx.clone();
229 let task = self
230 .sessions
231 .borrow()
232 .get(&arguments.session_id)
233 .context("Failed to get session")?
234 .thread
235 .update(cx, |thread, cx| {
236 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
237 })?;
238
239 let content = task.await?;
240
241 Ok(acp::ReadTextFileResponse { content })
242 }
243
244 async fn session_notification(
245 &self,
246 notification: acp::SessionNotification,
247 ) -> Result<(), acp::Error> {
248 let cx = &mut self.cx.clone();
249 let sessions = self.sessions.borrow();
250 let session = sessions
251 .get(¬ification.session_id)
252 .context("Failed to get session")?;
253
254 session.thread.update(cx, |thread, cx| {
255 thread.handle_session_update(notification.update, cx)
256 })??;
257
258 Ok(())
259 }
260}