1use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
2use agentic_coding_protocol as acp;
3use anyhow::{Context as _, Result};
4use async_trait::async_trait;
5use collections::HashMap;
6use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
7use parking_lot::Mutex;
8use project::Project;
9use smol::process::Child;
10use std::{process::ExitStatus, sync::Arc};
11use util::ResultExt;
12
/// Handle to an ACP agent connection.
///
/// Owns the protocol connection plus the tasks that drive it, and tracks the
/// [`AcpThread`]s created through it so that agent callbacks can be routed back to them.
pub struct AcpServer {
    /// Protocol connection used to send requests to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Threads created via this server, keyed by id. Entries are weak handles, and the map is
    /// shared with the client delegate that handles agent callbacks.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Project handed to every new [`AcpThread`].
    project: Entity<Project>,
    /// Exit status of the agent process, written by the I/O task after the connection's I/O
    /// finishes and the process status has been collected; `None` until then, if the status
    /// could not be obtained, or for servers built with `fake`.
    exit_status: Arc<Mutex<Option<ExitStatus>>>,
    // The two tasks below are held only so the spawned futures stay owned by the server
    // (presumably dropping a gpui `Task` cancels it — confirm). Note that field order also
    // determines the order in which they are dropped.
    _handler_task: Task<()>,
    _io_task: Task<()>,
}
21
22struct AcpClientDelegate {
23 threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
24 cx: AsyncApp,
25 // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
26}
27
28impl AcpClientDelegate {
29 fn new(threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>, cx: AsyncApp) -> Self {
30 Self { threads, cx: cx }
31 }
32
33 fn update_thread<R>(
34 &self,
35 thread_id: &ThreadId,
36 cx: &mut App,
37 callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
38 ) -> Option<R> {
39 let thread = self.threads.lock().get(&thread_id)?.clone();
40 let Some(thread) = thread.upgrade() else {
41 self.threads.lock().remove(&thread_id);
42 return None;
43 };
44 Some(thread.update(cx, callback))
45 }
46}
47
48#[async_trait(?Send)]
49impl acp::Client for AcpClientDelegate {
50 async fn stream_assistant_message_chunk(
51 &self,
52 params: acp::StreamAssistantMessageChunkParams,
53 ) -> Result<acp::StreamAssistantMessageChunkResponse> {
54 let cx = &mut self.cx.clone();
55
56 cx.update(|cx| {
57 self.update_thread(¶ms.thread_id.into(), cx, |thread, cx| {
58 thread.push_assistant_chunk(params.chunk, cx)
59 });
60 })?;
61
62 Ok(acp::StreamAssistantMessageChunkResponse)
63 }
64
65 async fn request_tool_call_confirmation(
66 &self,
67 request: acp::RequestToolCallConfirmationParams,
68 ) -> Result<acp::RequestToolCallConfirmationResponse> {
69 let cx = &mut self.cx.clone();
70 let ToolCallRequest { id, outcome } = cx
71 .update(|cx| {
72 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
73 thread.request_tool_call(
74 request.label,
75 request.icon,
76 request.content,
77 request.confirmation,
78 cx,
79 )
80 })
81 })?
82 .context("Failed to update thread")?;
83
84 Ok(acp::RequestToolCallConfirmationResponse {
85 id: id.into(),
86 outcome: outcome.await?,
87 })
88 }
89
90 async fn push_tool_call(
91 &self,
92 request: acp::PushToolCallParams,
93 ) -> Result<acp::PushToolCallResponse> {
94 let cx = &mut self.cx.clone();
95 let entry_id = cx
96 .update(|cx| {
97 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
98 thread.push_tool_call(request.label, request.icon, request.content, cx)
99 })
100 })?
101 .context("Failed to update thread")?;
102
103 Ok(acp::PushToolCallResponse {
104 id: entry_id.into(),
105 })
106 }
107
108 async fn update_tool_call(
109 &self,
110 request: acp::UpdateToolCallParams,
111 ) -> Result<acp::UpdateToolCallResponse> {
112 let cx = &mut self.cx.clone();
113
114 cx.update(|cx| {
115 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
116 thread.update_tool_call(
117 request.tool_call_id.into(),
118 request.status,
119 request.content,
120 cx,
121 )
122 })
123 })?
124 .context("Failed to update thread")??;
125
126 Ok(acp::UpdateToolCallResponse)
127 }
128}
129
130impl AcpServer {
131 pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut App) -> Arc<Self> {
132 let stdin = process.stdin.take().expect("process didn't have stdin");
133 let stdout = process.stdout.take().expect("process didn't have stdout");
134
135 let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
136 let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
137 AcpClientDelegate::new(threads.clone(), cx.to_async()),
138 stdin,
139 stdout,
140 );
141
142 let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
143 let io_task = cx.background_spawn({
144 let exit_status = exit_status.clone();
145 async move {
146 io_fut.await.log_err();
147 let result = process.status().await.log_err();
148 *exit_status.lock() = result;
149 }
150 });
151
152 Arc::new(Self {
153 project,
154 connection: Arc::new(connection),
155 threads,
156 exit_status,
157 _handler_task: cx.foreground_executor().spawn(handler_fut),
158 _io_task: io_task,
159 })
160 }
161
162 #[cfg(test)]
163 pub fn fake(
164 stdin: async_pipe::PipeWriter,
165 stdout: async_pipe::PipeReader,
166 project: Entity<Project>,
167 cx: &mut App,
168 ) -> Arc<Self> {
169 let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
170 let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
171 AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
172 stdin,
173 stdout,
174 );
175
176 let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
177 let io_task = cx.background_spawn({
178 async move {
179 io_fut.await.log_err();
180 // todo!() exit status?
181 }
182 });
183
184 Arc::new(Self {
185 project,
186 connection: Arc::new(connection),
187 threads,
188 exit_status,
189 _handler_task: cx.foreground_executor().spawn(handler_fut),
190 _io_task: io_task,
191 })
192 }
193
194 pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
195 self.connection
196 .request(acp::InitializeParams)
197 .await
198 .map_err(to_anyhow)
199 }
200
201 pub async fn authenticate(&self) -> Result<()> {
202 self.connection
203 .request(acp::AuthenticateParams)
204 .await
205 .map_err(to_anyhow)?;
206
207 Ok(())
208 }
209
210 pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
211 let response = self
212 .connection
213 .request(acp::CreateThreadParams)
214 .await
215 .map_err(to_anyhow)?;
216
217 let thread_id: ThreadId = response.thread_id.into();
218 let server = self.clone();
219 let thread = cx.new(|cx| {
220 AcpThread::new(
221 server,
222 thread_id.clone(),
223 Vec::default(),
224 self.project.clone(),
225 cx,
226 )
227 })?;
228 self.threads.lock().insert(thread_id, thread.downgrade());
229 Ok(thread)
230 }
231
232 pub async fn send_message(
233 &self,
234 thread_id: ThreadId,
235 message: acp::UserMessage,
236 _cx: &mut AsyncApp,
237 ) -> Result<()> {
238 self.connection
239 .request(acp::SendUserMessageParams {
240 thread_id: thread_id.clone().into(),
241 message,
242 })
243 .await
244 .map_err(to_anyhow)?;
245 Ok(())
246 }
247
248 pub async fn cancel_send_message(&self, thread_id: ThreadId, _cx: &mut AsyncApp) -> Result<()> {
249 self.connection
250 .request(acp::CancelSendMessageParams {
251 thread_id: thread_id.clone().into(),
252 })
253 .await
254 .map_err(to_anyhow)?;
255 Ok(())
256 }
257
258 pub fn exit_status(&self) -> Option<ExitStatus> {
259 *self.exit_status.lock()
260 }
261}
262
263#[track_caller]
264fn to_anyhow(e: acp::Error) -> anyhow::Error {
265 log::error!(
266 "failed to send message: {code}: {message}",
267 code = e.code,
268 message = e.message
269 );
270 anyhow::anyhow!(e.message)
271}
272
273impl From<acp::ThreadId> for ThreadId {
274 fn from(thread_id: acp::ThreadId) -> Self {
275 Self(thread_id.0.into())
276 }
277}
278
279impl From<ThreadId> for acp::ThreadId {
280 fn from(thread_id: ThreadId) -> Self {
281 acp::ThreadId(thread_id.0.to_string())
282 }
283}
284
285impl From<acp::ToolCallId> for ToolCallId {
286 fn from(tool_call_id: acp::ToolCallId) -> Self {
287 Self(ThreadEntryId(tool_call_id.0))
288 }
289}
290
291impl From<ToolCallId> for acp::ToolCallId {
292 fn from(tool_call_id: ToolCallId) -> Self {
293 acp::ToolCallId(tool_call_id.as_u64())
294 }
295}