server.rs

  1use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
  2use agentic_coding_protocol as acp;
  3use anyhow::{Context as _, Result};
  4use async_trait::async_trait;
  5use collections::HashMap;
  6use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  7use parking_lot::Mutex;
  8use project::Project;
  9use smol::process::Child;
 10use std::{io::Write as _, path::Path, process::ExitStatus, sync::Arc};
 11use util::ResultExt;
 12
/// Client-side handle to an agent process that speaks the agentic coding protocol (ACP).
///
/// Created by [`AcpServer::stdio`]; threads opened through it are tracked in `threads`.
pub struct AcpServer {
    /// Connection used to send requests to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Threads created through this server, keyed by id and held weakly. The same map is
    /// shared with the `AcpClientDelegate`, which prunes entries whose thread was released.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Project that agent file paths are resolved against.
    project: Entity<Project>,
    /// Exit status of the agent process; written by the I/O task after the connection's I/O
    /// future completes. `None` until then, or if waiting on the process failed.
    exit_status: Arc<Mutex<Option<ExitStatus>>>,
    /// Task running the connection's handler future (spawned on the foreground executor).
    _handler_task: Task<()>,
    /// Task driving the connection's I/O and then recording the process exit status.
    _io_task: Task<()>,
}
 21
/// Implements the `acp::Client` side of the connection: answers requests the agent makes,
/// resolving paths against `project` and forwarding thread updates to entries in `threads`.
struct AcpClientDelegate {
    /// Project used to resolve absolute paths and open buffers/files for the agent.
    project: Entity<Project>,
    /// Thread registry shared with the owning `AcpServer` (see `AcpServer::stdio`).
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Async app context; cloned by each request handler to reach entities.
    cx: AsyncApp,
    // NOTE(review): a `sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>`
    // field was sketched here, presumably to remember which buffer snapshots were sent to the
    // agent (file versions are currently hard-coded to 0). Confirm the plan or drop this note.
}
 28
 29impl AcpClientDelegate {
 30    fn new(
 31        project: Entity<Project>,
 32        threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
 33        cx: AsyncApp,
 34    ) -> Self {
 35        Self {
 36            project,
 37            threads,
 38            cx: cx,
 39        }
 40    }
 41
 42    fn update_thread<R>(
 43        &self,
 44        thread_id: &ThreadId,
 45        cx: &mut App,
 46        callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
 47    ) -> Option<R> {
 48        let thread = self.threads.lock().get(&thread_id)?.clone();
 49        let Some(thread) = thread.upgrade() else {
 50            self.threads.lock().remove(&thread_id);
 51            return None;
 52        };
 53        Some(thread.update(cx, callback))
 54    }
 55}
 56
 57#[async_trait(?Send)]
 58impl acp::Client for AcpClientDelegate {
 59    async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
 60        let cx = &mut self.cx.clone();
 61        self.project.update(cx, |project, cx| {
 62            let path = project
 63                .project_path_for_absolute_path(Path::new(&params.path), cx)
 64                .context("Failed to get project path")?;
 65
 66            match project.entry_for_path(&path, cx) {
 67                // todo! refresh entry?
 68                None => Ok(acp::StatResponse {
 69                    exists: false,
 70                    is_directory: false,
 71                    size: 0,
 72                }),
 73                Some(entry) => Ok(acp::StatResponse {
 74                    exists: entry.is_created(),
 75                    is_directory: entry.is_dir(),
 76                    size: entry.size,
 77                }),
 78            }
 79        })?
 80    }
 81
 82    async fn stream_message_chunk(
 83        &self,
 84        params: acp::StreamMessageChunkParams,
 85    ) -> Result<acp::StreamMessageChunkResponse> {
 86        let cx = &mut self.cx.clone();
 87
 88        cx.update(|cx| {
 89            self.update_thread(&params.thread_id.into(), cx, |thread, cx| {
 90                thread.push_assistant_chunk(params.chunk, cx)
 91            });
 92        })?;
 93
 94        Ok(acp::StreamMessageChunkResponse)
 95    }
 96
 97    async fn read_text_file(
 98        &self,
 99        request: acp::ReadTextFileParams,
100    ) -> Result<acp::ReadTextFileResponse> {
101        let cx = &mut self.cx.clone();
102        let buffer = self
103            .project
104            .update(cx, |project, cx| {
105                let path = project
106                    .project_path_for_absolute_path(Path::new(&request.path), cx)
107                    .context("Failed to get project path")?;
108                anyhow::Ok(project.open_buffer(path, cx))
109            })??
110            .await?;
111
112        buffer.update(cx, |buffer, _cx| {
113            let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
114            let end = match request.line_limit {
115                None => buffer.max_point(),
116                Some(limit) => start + language::Point::new(limit + 1, 0),
117            };
118
119            let content: String = buffer.text_for_range(start..end).collect();
120
121            acp::ReadTextFileResponse {
122                content,
123                version: acp::FileVersion(0),
124            }
125        })
126    }
127
128    async fn read_binary_file(
129        &self,
130        request: acp::ReadBinaryFileParams,
131    ) -> Result<acp::ReadBinaryFileResponse> {
132        let cx = &mut self.cx.clone();
133        let file = self
134            .project
135            .update(cx, |project, cx| {
136                let (worktree, path) = project
137                    .find_worktree(Path::new(&request.path), cx)
138                    .context("Failed to get project path")?;
139
140                let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
141                anyhow::Ok(task)
142            })??
143            .await?;
144
145        // todo! test
146        let content = cx
147            .background_spawn(async move {
148                let start = request.byte_offset.unwrap_or(0) as usize;
149                let end = request
150                    .byte_limit
151                    .map(|limit| (start + limit as usize).min(file.content.len()))
152                    .unwrap_or(file.content.len());
153
154                let range_content = &file.content[start..end];
155
156                let mut base64_content = Vec::new();
157                let mut base64_encoder = base64::write::EncoderWriter::new(
158                    std::io::Cursor::new(&mut base64_content),
159                    &base64::engine::general_purpose::STANDARD,
160                );
161                base64_encoder.write_all(range_content)?;
162                drop(base64_encoder);
163
164                // SAFETY: The base64 encoder should not produce non-UTF8.
165                unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
166            })
167            .await?;
168
169        Ok(acp::ReadBinaryFileResponse {
170            content,
171            // todo!
172            version: acp::FileVersion(0),
173        })
174    }
175
176    async fn glob_search(
177        &self,
178        _request: acp::GlobSearchParams,
179    ) -> Result<acp::GlobSearchResponse> {
180        todo!()
181    }
182
183    async fn request_tool_call_confirmation(
184        &self,
185        request: acp::RequestToolCallConfirmationParams,
186    ) -> Result<acp::RequestToolCallConfirmationResponse> {
187        let cx = &mut self.cx.clone();
188        let ToolCallRequest { id, outcome } = cx
189            .update(|cx| {
190                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
191                    thread.request_tool_call(request.label, request.icon, request.confirmation, cx)
192                })
193            })?
194            .context("Failed to update thread")?;
195
196        Ok(acp::RequestToolCallConfirmationResponse {
197            id: id.into(),
198            outcome: outcome.await?,
199        })
200    }
201
202    async fn push_tool_call(
203        &self,
204        request: acp::PushToolCallParams,
205    ) -> Result<acp::PushToolCallResponse> {
206        let cx = &mut self.cx.clone();
207        let entry_id = cx
208            .update(|cx| {
209                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
210                    thread.push_tool_call(request.label, request.icon, cx)
211                })
212            })?
213            .context("Failed to update thread")?;
214
215        Ok(acp::PushToolCallResponse {
216            id: entry_id.into(),
217        })
218    }
219
220    async fn update_tool_call(
221        &self,
222        request: acp::UpdateToolCallParams,
223    ) -> Result<acp::UpdateToolCallResponse> {
224        let cx = &mut self.cx.clone();
225
226        cx.update(|cx| {
227            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
228                thread.update_tool_call(
229                    request.tool_call_id.into(),
230                    request.status,
231                    request.content,
232                    cx,
233                )
234            })
235        })?
236        .context("Failed to update thread")??;
237
238        Ok(acp::UpdateToolCallResponse)
239    }
240}
241
242impl AcpServer {
243    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut App) -> Arc<Self> {
244        let stdin = process.stdin.take().expect("process didn't have stdin");
245        let stdout = process.stdout.take().expect("process didn't have stdout");
246
247        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
248        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
249            AcpClientDelegate::new(project.clone(), threads.clone(), cx.to_async()),
250            stdin,
251            stdout,
252        );
253
254        let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
255        let io_task = cx.background_spawn({
256            let exit_status = exit_status.clone();
257            async move {
258                io_fut.await.log_err();
259                let result = process.status().await.log_err();
260                *exit_status.lock() = result;
261            }
262        });
263
264        Arc::new(Self {
265            project,
266            connection: Arc::new(connection),
267            threads,
268            exit_status,
269            _handler_task: cx.foreground_executor().spawn(handler_fut),
270            _io_task: io_task,
271        })
272    }
273
274    pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
275        self.connection
276            .request(acp::InitializeParams)
277            .await
278            .map_err(to_anyhow)
279    }
280
281    pub async fn authenticate(&self) -> Result<()> {
282        self.connection
283            .request(acp::AuthenticateParams)
284            .await
285            .map_err(to_anyhow)?;
286
287        Ok(())
288    }
289
290    pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
291        let response = self
292            .connection
293            .request(acp::CreateThreadParams)
294            .await
295            .map_err(to_anyhow)?;
296
297        let thread_id: ThreadId = response.thread_id.into();
298        let server = self.clone();
299        let thread = cx.new(|_| AcpThread {
300            // todo!
301            title: "ACP Thread".into(),
302            id: thread_id.clone(), // Either<ErrorState, Id>
303            next_entry_id: ThreadEntryId(0),
304            entries: Vec::default(),
305            project: self.project.clone(),
306            server,
307        })?;
308        self.threads.lock().insert(thread_id, thread.downgrade());
309        Ok(thread)
310    }
311
312    pub async fn send_message(
313        &self,
314        thread_id: ThreadId,
315        message: acp::Message,
316        _cx: &mut AsyncApp,
317    ) -> Result<()> {
318        self.connection
319            .request(acp::SendMessageParams {
320                thread_id: thread_id.clone().into(),
321                message,
322            })
323            .await
324            .map_err(to_anyhow)?;
325        Ok(())
326    }
327
328    pub fn exit_status(&self) -> Option<ExitStatus> {
329        self.exit_status.lock().clone()
330    }
331}
332
333#[track_caller]
334fn to_anyhow(e: acp::Error) -> anyhow::Error {
335    log::error!(
336        "failed to send message: {code}: {message}",
337        code = e.code,
338        message = e.message
339    );
340    anyhow::anyhow!(e.message)
341}
342
343impl From<acp::ThreadId> for ThreadId {
344    fn from(thread_id: acp::ThreadId) -> Self {
345        Self(thread_id.0.into())
346    }
347}
348
349impl From<ThreadId> for acp::ThreadId {
350    fn from(thread_id: ThreadId) -> Self {
351        acp::ThreadId(thread_id.0.to_string())
352    }
353}
354
355impl From<acp::ToolCallId> for ToolCallId {
356    fn from(tool_call_id: acp::ToolCallId) -> Self {
357        Self(ThreadEntryId(tool_call_id.0))
358    }
359}
360
361impl From<ToolCallId> for acp::ToolCallId {
362    fn from(tool_call_id: ToolCallId) -> Self {
363        acp::ToolCallId(tool_call_id.as_u64())
364    }
365}