server.rs

  1use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId};
  2use agentic_coding_protocol as acp;
  3use anyhow::{Context as _, Result};
  4use async_trait::async_trait;
  5use collections::HashMap;
  6use futures::channel::oneshot;
  7use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  8use parking_lot::Mutex;
  9use project::Project;
 10use smol::process::Child;
 11use std::{io::Write as _, path::Path, sync::Arc};
 12use util::ResultExt;
 13
/// A connection to an Agent Client Protocol (ACP) agent process.
///
/// Owns the RPC connection together with the tasks that drive it. The
/// underscore-prefixed task fields are never read; they are stored only so the
/// tasks live exactly as long as the server.
pub struct AcpServer {
    /// Handle used to send requests (e.g. thread creation, messages) to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Live threads keyed by id. Shared with the client delegate so requests the
    /// agent sends can be routed to the right [`AcpThread`]; entries are weak so
    /// this registry does not keep threads alive.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Project handed to every thread created by this server.
    project: Entity<Project>,
    /// Foreground task running the connection's handler future (which dispatches
    /// to the client delegate).
    _handler_task: Task<()>,
    /// Background task running the connection's stdio future and then awaiting
    /// the agent process's exit status.
    // NOTE(review): field order determines drop order of the two tasks — keep as is.
    _io_task: Task<()>,
}
 21
/// Implements the client side of ACP: handles the requests and notifications
/// the agent sends back to us.
///
/// Shares its thread registry with the owning [`AcpServer`] so thread-scoped
/// calls can be routed to the corresponding [`AcpThread`].
struct AcpClientDelegate {
    /// Project used to resolve the agent's absolute paths and open buffers.
    project: Entity<Project>,
    /// Registry of threads keyed by id; entries hold weak handles and stale ones
    /// are pruned lazily when looked up.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// App context; each handler clones it to reach app state from async code.
    cx: AsyncApp,
    // NOTE(review): a per-buffer record of versions sent to the agent was sketched
    // here but never implemented (file versions are currently reported as 0):
    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
 28
 29impl AcpClientDelegate {
 30    fn new(
 31        project: Entity<Project>,
 32        threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
 33        cx: AsyncApp,
 34    ) -> Self {
 35        Self {
 36            project,
 37            threads,
 38            cx: cx,
 39        }
 40    }
 41
 42    fn update_thread<R>(
 43        &self,
 44        thread_id: &ThreadId,
 45        cx: &mut App,
 46        callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
 47    ) -> Option<R> {
 48        let thread = self.threads.lock().get(&thread_id)?.clone();
 49        let Some(thread) = thread.upgrade() else {
 50            self.threads.lock().remove(&thread_id);
 51            return None;
 52        };
 53        Some(thread.update(cx, callback))
 54    }
 55}
 56
 57#[async_trait(?Send)]
 58impl acp::Client for AcpClientDelegate {
 59    async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
 60        let cx = &mut self.cx.clone();
 61        self.project.update(cx, |project, cx| {
 62            let path = project
 63                .project_path_for_absolute_path(Path::new(&params.path), cx)
 64                .context("Failed to get project path")?;
 65
 66            match project.entry_for_path(&path, cx) {
 67                // todo! refresh entry?
 68                None => Ok(acp::StatResponse {
 69                    exists: false,
 70                    is_directory: false,
 71                }),
 72                Some(entry) => Ok(acp::StatResponse {
 73                    exists: entry.is_created(),
 74                    is_directory: entry.is_dir(),
 75                }),
 76            }
 77        })?
 78    }
 79
 80    async fn stream_message_chunk(
 81        &self,
 82        params: acp::StreamMessageChunkParams,
 83    ) -> Result<acp::StreamMessageChunkResponse> {
 84        let cx = &mut self.cx.clone();
 85
 86        cx.update(|cx| {
 87            self.update_thread(&params.thread_id.into(), cx, |thread, cx| {
 88                thread.push_assistant_chunk(params.chunk, cx)
 89            });
 90        })?;
 91
 92        Ok(acp::StreamMessageChunkResponse)
 93    }
 94
 95    async fn read_text_file(
 96        &self,
 97        request: acp::ReadTextFileParams,
 98    ) -> Result<acp::ReadTextFileResponse> {
 99        let cx = &mut self.cx.clone();
100        let buffer = self
101            .project
102            .update(cx, |project, cx| {
103                let path = project
104                    .project_path_for_absolute_path(Path::new(&request.path), cx)
105                    .context("Failed to get project path")?;
106                anyhow::Ok(project.open_buffer(path, cx))
107            })??
108            .await?;
109
110        buffer.update(cx, |buffer, _cx| {
111            let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
112            let end = match request.line_limit {
113                None => buffer.max_point(),
114                Some(limit) => start + language::Point::new(limit + 1, 0),
115            };
116
117            let content: String = buffer.text_for_range(start..end).collect();
118
119            acp::ReadTextFileResponse {
120                content,
121                version: acp::FileVersion(0),
122            }
123        })
124    }
125
126    async fn read_binary_file(
127        &self,
128        request: acp::ReadBinaryFileParams,
129    ) -> Result<acp::ReadBinaryFileResponse> {
130        let cx = &mut self.cx.clone();
131        let file = self
132            .project
133            .update(cx, |project, cx| {
134                let (worktree, path) = project
135                    .find_worktree(Path::new(&request.path), cx)
136                    .context("Failed to get project path")?;
137
138                let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
139                anyhow::Ok(task)
140            })??
141            .await?;
142
143        // todo! test
144        let content = cx
145            .background_spawn(async move {
146                let start = request.byte_offset.unwrap_or(0) as usize;
147                let end = request
148                    .byte_limit
149                    .map(|limit| (start + limit as usize).min(file.content.len()))
150                    .unwrap_or(file.content.len());
151
152                let range_content = &file.content[start..end];
153
154                let mut base64_content = Vec::new();
155                let mut base64_encoder = base64::write::EncoderWriter::new(
156                    std::io::Cursor::new(&mut base64_content),
157                    &base64::engine::general_purpose::STANDARD,
158                );
159                base64_encoder.write_all(range_content)?;
160                drop(base64_encoder);
161
162                // SAFETY: The base64 encoder should not produce non-UTF8.
163                unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
164            })
165            .await?;
166
167        Ok(acp::ReadBinaryFileResponse {
168            content,
169            // todo!
170            version: acp::FileVersion(0),
171        })
172    }
173
174    async fn glob_search(
175        &self,
176        _request: acp::GlobSearchParams,
177    ) -> Result<acp::GlobSearchResponse> {
178        todo!()
179    }
180
181    async fn request_tool_call_confirmation(
182        &self,
183        request: acp::RequestToolCallConfirmationParams,
184    ) -> Result<acp::RequestToolCallConfirmationResponse> {
185        let (tx, rx) = oneshot::channel();
186
187        let cx = &mut self.cx.clone();
188        let entry_id = cx
189            .update(|cx| {
190                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
191                    // todo! Should we pass through richer data than a description?
192                    let description = match request.confirmation {
193                        acp::ToolCallConfirmation::Edit {
194                            file_name,
195                            file_diff,
196                        } => {
197                            // todo! Nicer syntax/presentation based on file extension? Better way to communicate diff?
198                            format!("Edit file `{file_name}` with diff:\n```\n{file_diff}\n```")
199                        }
200                        acp::ToolCallConfirmation::Execute {
201                            command,
202                            root_command: _,
203                        } => {
204                            format!("Execute command `{command}`")
205                        }
206                        acp::ToolCallConfirmation::Mcp {
207                            server_name,
208                            tool_name: _,
209                            tool_display_name,
210                        } => {
211                            format!("MCP: {server_name} - {tool_display_name}")
212                        }
213                        acp::ToolCallConfirmation::Info { prompt, urls } => {
214                            format!("Info: {prompt}\n{urls:?}")
215                        }
216                    };
217                    thread.push_tool_call(request.title, description, Some(tx), cx)
218                })
219            })?
220            .context("Failed to update thread")?;
221
222        let outcome = if rx.await? {
223            // todo! Handle other outcomes
224            acp::ToolCallConfirmationOutcome::Allow
225        } else {
226            acp::ToolCallConfirmationOutcome::Reject
227        };
228        Ok(acp::RequestToolCallConfirmationResponse {
229            id: entry_id.into(),
230            outcome,
231        })
232    }
233
234    async fn push_tool_call(
235        &self,
236        request: acp::PushToolCallParams,
237    ) -> Result<acp::PushToolCallResponse> {
238        let cx = &mut self.cx.clone();
239        let entry_id = cx
240            .update(|cx| {
241                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
242                    thread.push_tool_call(request.title, request.description, None, cx)
243                })
244            })?
245            .context("Failed to update thread")?;
246
247        Ok(acp::PushToolCallResponse {
248            id: entry_id.into(),
249        })
250    }
251
252    async fn update_tool_call(
253        &self,
254        request: acp::UpdateToolCallParams,
255    ) -> Result<acp::UpdateToolCallResponse> {
256        let cx = &mut self.cx.clone();
257
258        cx.update(|cx| {
259            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
260                thread.update_tool_call(
261                    request.tool_call_id.into(),
262                    request.status,
263                    request.content,
264                    cx,
265                )
266            })
267        })?
268        .context("Failed to update thread")??;
269
270        Ok(acp::UpdateToolCallResponse)
271    }
272}
273
274impl AcpServer {
275    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
276        let stdin = process.stdin.take().expect("process didn't have stdin");
277        let stdout = process.stdout.take().expect("process didn't have stdout");
278
279        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
280        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
281            AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
282            stdin,
283            stdout,
284        );
285
286        let io_task = cx.background_spawn(async move {
287            io_fut.await.log_err();
288            process.status().await.log_err();
289        });
290
291        Arc::new(Self {
292            project,
293            connection: Arc::new(connection),
294            threads,
295            _handler_task: cx.foreground_executor().spawn(handler_fut),
296            _io_task: io_task,
297        })
298    }
299}
300
301impl AcpServer {
302    pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
303        let response = self.connection.request(acp::CreateThreadParams).await?;
304        let thread_id: ThreadId = response.thread_id.into();
305        let server = self.clone();
306        let thread = cx.new(|_| AcpThread {
307            // todo!
308            title: "ACP Thread".into(),
309            id: thread_id.clone(),
310            next_entry_id: ThreadEntryId(0),
311            entries: Vec::default(),
312            project: self.project.clone(),
313            server,
314        })?;
315        self.threads.lock().insert(thread_id, thread.downgrade());
316        Ok(thread)
317    }
318
319    pub async fn send_message(
320        &self,
321        thread_id: ThreadId,
322        message: acp::Message,
323        _cx: &mut AsyncApp,
324    ) -> Result<()> {
325        self.connection
326            .request(acp::SendMessageParams {
327                thread_id: thread_id.clone().into(),
328                message,
329            })
330            .await?;
331        Ok(())
332    }
333}
334
335impl From<acp::ThreadId> for ThreadId {
336    fn from(thread_id: acp::ThreadId) -> Self {
337        Self(thread_id.0.into())
338    }
339}
340
341impl From<ThreadId> for acp::ThreadId {
342    fn from(thread_id: ThreadId) -> Self {
343        acp::ThreadId(thread_id.0.to_string())
344    }
345}
346
347impl From<acp::ToolCallId> for ToolCallId {
348    fn from(tool_call_id: acp::ToolCallId) -> Self {
349        Self(ThreadEntryId(tool_call_id.0))
350    }
351}
352
353impl From<ToolCallId> for acp::ToolCallId {
354    fn from(tool_call_id: ToolCallId) -> Self {
355        acp::ToolCallId(tool_call_id.as_u64())
356    }
357}