1use crate::{AcpThread, AgentThreadEntryContent, ThreadEntryId, ThreadId};
2use agentic_coding_protocol as acp;
3use anyhow::{Context as _, Result};
4use async_trait::async_trait;
5use collections::HashMap;
6use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
7use parking_lot::Mutex;
8use project::Project;
9use smol::process::Child;
10use std::{io::Write as _, path::Path, sync::Arc};
11use util::ResultExt;
12
/// Client-side handle to an agent process that speaks the agentic coding protocol (ACP).
///
/// Constructed via [`AcpServer::stdio`]; owns the connection used to issue requests and the
/// tasks that drive message handling and I/O for that connection.
pub struct AcpServer {
    /// Connection used to send requests (e.g. `CreateThreadParams`, `SendMessageParams`) to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Registry of threads created through this server, held weakly and keyed by id. The same
    /// `Arc` is shared with `AcpClientDelegate` so agent callbacks can find their thread.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Project that new threads are attached to.
    project: Entity<Project>,
    // The two tasks below are stored only to keep them alive for the server's lifetime.
    // NOTE(review): presumably dropping a gpui `Task` cancels it — confirm before relying on
    // field drop order here.
    _handler_task: Task<()>,
    _io_task: Task<()>,
}
20
/// Handles requests and notifications that the agent sends back to the client
/// (file stats/reads, streamed message chunks), resolving them against the project.
struct AcpClientDelegate {
    /// Project used to resolve absolute paths and open buffers/files.
    project: Entity<Project>,
    /// Weak thread registry shared with `AcpServer`; used to route thread-scoped callbacks.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Async app context cloned into each request handler.
    cx: AsyncApp,
    // NOTE(review): commented-out field below looks like planned per-buffer version tracking
    // (responses currently always report `FileVersion(0)`); remove or implement.
    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
27
28impl AcpClientDelegate {
29 fn new(
30 project: Entity<Project>,
31 threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
32 cx: AsyncApp,
33 ) -> Self {
34 Self {
35 project,
36 threads,
37 cx: cx,
38 }
39 }
40
41 fn update_thread<R>(
42 &self,
43 thread_id: &ThreadId,
44 cx: &mut App,
45 callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
46 ) -> Option<R> {
47 let thread = self.threads.lock().get(&thread_id)?.clone();
48 let Some(thread) = thread.upgrade() else {
49 self.threads.lock().remove(&thread_id);
50 return None;
51 };
52 Some(thread.update(cx, callback))
53 }
54}
55
56#[async_trait(?Send)]
57impl acp::Client for AcpClientDelegate {
58 async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
59 let cx = &mut self.cx.clone();
60 self.project.update(cx, |project, cx| {
61 let path = project
62 .project_path_for_absolute_path(Path::new(¶ms.path), cx)
63 .context("Failed to get project path")?;
64
65 match project.entry_for_path(&path, cx) {
66 // todo! refresh entry?
67 None => Ok(acp::StatResponse {
68 exists: false,
69 is_directory: false,
70 }),
71 Some(entry) => Ok(acp::StatResponse {
72 exists: entry.is_created(),
73 is_directory: entry.is_dir(),
74 }),
75 }
76 })?
77 }
78
79 async fn stream_message_chunk(
80 &self,
81 params: acp::StreamMessageChunkParams,
82 ) -> Result<acp::StreamMessageChunkResponse> {
83 let cx = &mut self.cx.clone();
84
85 cx.update(|cx| {
86 self.update_thread(¶ms.thread_id.into(), cx, |thread, cx| {
87 thread.push_assistant_chunk(params.chunk, cx)
88 });
89 })?;
90
91 Ok(acp::StreamMessageChunkResponse)
92 }
93
94 async fn read_text_file(
95 &self,
96 request: acp::ReadTextFileParams,
97 ) -> Result<acp::ReadTextFileResponse> {
98 let cx = &mut self.cx.clone();
99 let buffer = self
100 .project
101 .update(cx, |project, cx| {
102 let path = project
103 .project_path_for_absolute_path(Path::new(&request.path), cx)
104 .context("Failed to get project path")?;
105 anyhow::Ok(project.open_buffer(path, cx))
106 })??
107 .await?;
108
109 buffer.update(cx, |buffer, cx| {
110 let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
111 let end = match request.line_limit {
112 None => buffer.max_point(),
113 Some(limit) => start + language::Point::new(limit + 1, 0),
114 };
115
116 let content: String = buffer.text_for_range(start..end).collect();
117 self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
118 thread.push_entry(
119 AgentThreadEntryContent::ReadFile {
120 path: request.path.clone(),
121 content: content.clone(),
122 },
123 cx,
124 );
125 });
126
127 acp::ReadTextFileResponse {
128 content,
129 version: acp::FileVersion(0),
130 }
131 })
132 }
133
134 async fn read_binary_file(
135 &self,
136 request: acp::ReadBinaryFileParams,
137 ) -> Result<acp::ReadBinaryFileResponse> {
138 let cx = &mut self.cx.clone();
139 let file = self
140 .project
141 .update(cx, |project, cx| {
142 let (worktree, path) = project
143 .find_worktree(Path::new(&request.path), cx)
144 .context("Failed to get project path")?;
145
146 let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
147 anyhow::Ok(task)
148 })??
149 .await?;
150
151 // todo! test
152 let content = cx
153 .background_spawn(async move {
154 let start = request.byte_offset.unwrap_or(0) as usize;
155 let end = request
156 .byte_limit
157 .map(|limit| (start + limit as usize).min(file.content.len()))
158 .unwrap_or(file.content.len());
159
160 let range_content = &file.content[start..end];
161
162 let mut base64_content = Vec::new();
163 let mut base64_encoder = base64::write::EncoderWriter::new(
164 std::io::Cursor::new(&mut base64_content),
165 &base64::engine::general_purpose::STANDARD,
166 );
167 base64_encoder.write_all(range_content)?;
168 drop(base64_encoder);
169
170 // SAFETY: The base64 encoder should not produce non-UTF8.
171 unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
172 })
173 .await?;
174
175 Ok(acp::ReadBinaryFileResponse {
176 content,
177 // todo!
178 version: acp::FileVersion(0),
179 })
180 }
181
182 async fn glob_search(
183 &self,
184 _request: acp::GlobSearchParams,
185 ) -> Result<acp::GlobSearchResponse> {
186 todo!()
187 }
188}
189
190impl AcpServer {
191 pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
192 let stdin = process.stdin.take().expect("process didn't have stdin");
193 let stdout = process.stdout.take().expect("process didn't have stdout");
194
195 let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
196 let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
197 AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
198 stdin,
199 stdout,
200 );
201
202 let io_task = cx.background_spawn(async move {
203 io_fut.await.log_err();
204 process.status().await.log_err();
205 });
206
207 Arc::new(Self {
208 project,
209 connection: Arc::new(connection),
210 threads,
211 _handler_task: cx.foreground_executor().spawn(handler_fut),
212 _io_task: io_task,
213 })
214 }
215}
216
217impl AcpServer {
218 pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
219 let response = self.connection.request(acp::CreateThreadParams).await?;
220 let thread_id: ThreadId = response.thread_id.into();
221 let server = self.clone();
222 let thread = cx.new(|_| AcpThread {
223 title: "The agent2 thread".into(),
224 id: thread_id.clone(),
225 next_entry_id: ThreadEntryId(0),
226 entries: Vec::default(),
227 project: self.project.clone(),
228 server,
229 })?;
230 self.threads.lock().insert(thread_id, thread.downgrade());
231 Ok(thread)
232 }
233
234 pub async fn send_message(
235 &self,
236 thread_id: ThreadId,
237 message: acp::Message,
238 _cx: &mut AsyncApp,
239 ) -> Result<()> {
240 self.connection
241 .request(acp::SendMessageParams {
242 thread_id: thread_id.clone().into(),
243 message,
244 })
245 .await?;
246 Ok(())
247 }
248}
249
250impl From<acp::ThreadId> for ThreadId {
251 fn from(thread_id: acp::ThreadId) -> Self {
252 Self(thread_id.0.into())
253 }
254}
255
256impl From<ThreadId> for acp::ThreadId {
257 fn from(thread_id: ThreadId) -> Self {
258 acp::ThreadId(thread_id.0.to_string())
259 }
260}