1use agent_client_protocol::{self as acp, Agent as _};
2use collections::HashMap;
3use futures::channel::oneshot;
4use project::Project;
5use std::cell::RefCell;
6use std::path::Path;
7use std::rc::Rc;
8
9use anyhow::{Context as _, Result};
10use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
11
12use crate::{AgentServerCommand, acp::UnsupportedVersion};
13use acp_thread::{AcpThread, AgentConnection, AuthRequired};
14
/// A connection to an agent server that speaks the Agent Client Protocol (ACP) over the
/// stdin/stdout of a spawned child process (see [`AcpConnection::stdio`]).
///
/// NOTE(review): fields are dropped in declaration order, so `_io_task` is dropped before
/// `_child` — keep that in mind before reordering them.
pub struct AcpConnection {
    /// Name of the server; passed to every [`AcpThread`] created by `new_thread`.
    server_name: &'static str,
    /// Client side of the ACP connection; cloned into the tasks that issue requests.
    connection: Rc<acp::ClientSideConnection>,
    /// Sessions opened on this connection, keyed by session id. Shared with the
    /// `ClientDelegate` so server-initiated requests/notifications reach the right thread.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Authentication methods the server advertised in its `initialize` response.
    auth_methods: Vec<acp::AuthMethod>,
    /// Task driving the connection's I/O; held only so it lives as long as the connection.
    _io_task: Task<Result<()>>,
    /// The agent server process. It is spawned with `kill_on_drop(true)`, so dropping the
    /// connection terminates it.
    _child: smol::process::Child,
}
23
/// Per-session state tracked by an [`AcpConnection`].
pub struct AcpSession {
    /// The thread this session's updates are routed to. Only a weak handle is kept, so the
    /// session map does not keep the thread alive.
    thread: WeakEntity<AcpThread>,
}
27
/// Oldest ACP protocol version this client accepts. `AcpConnection::stdio` rejects a server
/// that negotiates an older version with `UnsupportedVersion`.
const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
29
30impl AcpConnection {
31 pub async fn stdio(
32 server_name: &'static str,
33 command: AgentServerCommand,
34 root_dir: &Path,
35 cx: &mut AsyncApp,
36 ) -> Result<Self> {
37 let mut child = util::command::new_smol_command(&command.path)
38 .args(command.args.iter().map(|arg| arg.as_str()))
39 .envs(command.env.iter().flatten())
40 .current_dir(root_dir)
41 .stdin(std::process::Stdio::piped())
42 .stdout(std::process::Stdio::piped())
43 .stderr(std::process::Stdio::inherit())
44 .kill_on_drop(true)
45 .spawn()?;
46
47 let stdout = child.stdout.take().expect("Failed to take stdout");
48 let stdin = child.stdin.take().expect("Failed to take stdin");
49
50 let sessions = Rc::new(RefCell::new(HashMap::default()));
51
52 let client = ClientDelegate {
53 sessions: sessions.clone(),
54 cx: cx.clone(),
55 };
56 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
57 let foreground_executor = cx.foreground_executor().clone();
58 move |fut| {
59 foreground_executor.spawn(fut).detach();
60 }
61 });
62
63 let io_task = cx.background_spawn(io_task);
64
65 let response = connection
66 .initialize(acp::InitializeRequest {
67 protocol_version: acp::VERSION,
68 client_capabilities: acp::ClientCapabilities {
69 fs: acp::FileSystemCapability {
70 read_text_file: true,
71 write_text_file: true,
72 },
73 },
74 })
75 .await?;
76
77 if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
78 return Err(UnsupportedVersion.into());
79 }
80
81 Ok(Self {
82 auth_methods: response.auth_methods,
83 connection: connection.into(),
84 server_name,
85 sessions,
86 _child: child,
87 _io_task: io_task,
88 })
89 }
90}
91
92impl AgentConnection for AcpConnection {
93 fn new_thread(
94 self: Rc<Self>,
95 project: Entity<Project>,
96 cwd: &Path,
97 cx: &mut AsyncApp,
98 ) -> Task<Result<Entity<AcpThread>>> {
99 let conn = self.connection.clone();
100 let sessions = self.sessions.clone();
101 let cwd = cwd.to_path_buf();
102 cx.spawn(async move |cx| {
103 let response = conn
104 .new_session(acp::NewSessionRequest {
105 mcp_servers: vec![],
106 cwd,
107 })
108 .await?;
109
110 let Some(session_id) = response.session_id else {
111 anyhow::bail!(AuthRequired);
112 };
113
114 let thread = cx.new(|cx| {
115 AcpThread::new(
116 self.server_name,
117 self.clone(),
118 project,
119 session_id.clone(),
120 cx,
121 )
122 })?;
123
124 let session = AcpSession {
125 thread: thread.downgrade(),
126 };
127 sessions.borrow_mut().insert(session_id, session);
128
129 Ok(thread)
130 })
131 }
132
133 fn auth_methods(&self) -> &[acp::AuthMethod] {
134 &self.auth_methods
135 }
136
137 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
138 let conn = self.connection.clone();
139 cx.foreground_executor().spawn(async move {
140 let result = conn
141 .authenticate(acp::AuthenticateRequest {
142 method_id: method_id.clone(),
143 })
144 .await?;
145
146 Ok(result)
147 })
148 }
149
150 fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
151 let conn = self.connection.clone();
152 cx.foreground_executor()
153 .spawn(async move { Ok(conn.prompt(params).await?) })
154 }
155
156 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
157 let conn = self.connection.clone();
158 let params = acp::CancelledNotification {
159 session_id: session_id.clone(),
160 };
161 cx.foreground_executor()
162 .spawn(async move { conn.cancelled(params).await })
163 .detach();
164 }
165}
166
/// Handles requests and notifications the agent server sends to the client, dispatching
/// each one to the [`AcpThread`] registered for its session.
struct ClientDelegate {
    /// Session map shared with the owning [`AcpConnection`].
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// App context used to update session threads from the handlers.
    cx: AsyncApp,
}
171
172impl acp::Client for ClientDelegate {
173 async fn request_permission(
174 &self,
175 arguments: acp::RequestPermissionRequest,
176 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
177 let cx = &mut self.cx.clone();
178 let rx = self
179 .sessions
180 .borrow()
181 .get(&arguments.session_id)
182 .context("Failed to get session")?
183 .thread
184 .update(cx, |thread, cx| {
185 thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
186 })?;
187
188 let result = rx.await;
189
190 let outcome = match result {
191 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
192 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
193 };
194
195 Ok(acp::RequestPermissionResponse { outcome })
196 }
197
198 async fn write_text_file(
199 &self,
200 arguments: acp::WriteTextFileRequest,
201 ) -> Result<(), acp::Error> {
202 let cx = &mut self.cx.clone();
203 let task = self
204 .sessions
205 .borrow()
206 .get(&arguments.session_id)
207 .context("Failed to get session")?
208 .thread
209 .update(cx, |thread, cx| {
210 thread.write_text_file(arguments.path, arguments.content, cx)
211 })?;
212
213 task.await?;
214
215 Ok(())
216 }
217
218 async fn read_text_file(
219 &self,
220 arguments: acp::ReadTextFileRequest,
221 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
222 let cx = &mut self.cx.clone();
223 let task = self
224 .sessions
225 .borrow()
226 .get(&arguments.session_id)
227 .context("Failed to get session")?
228 .thread
229 .update(cx, |thread, cx| {
230 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
231 })?;
232
233 let content = task.await?;
234
235 Ok(acp::ReadTextFileResponse { content })
236 }
237
238 async fn session_notification(
239 &self,
240 notification: acp::SessionNotification,
241 ) -> Result<(), acp::Error> {
242 let cx = &mut self.cx.clone();
243 let sessions = self.sessions.borrow();
244 let session = sessions
245 .get(¬ification.session_id)
246 .context("Failed to get session")?;
247
248 session.thread.update(cx, |thread, cx| {
249 thread.handle_session_update(notification.update, cx)
250 })??;
251
252 Ok(())
253 }
254}