server.rs

  1use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
  2use agentic_coding_protocol as acp;
  3use anyhow::{Context as _, Result};
  4use async_trait::async_trait;
  5use collections::HashMap;
  6use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  7use parking_lot::Mutex;
  8use project::Project;
  9use smol::process::Child;
 10use std::{io::Write as _, path::Path, sync::Arc};
 11use util::ResultExt;
 12
 /// Client-side handle to an agent process that speaks the Agent Client Protocol (ACP).
///
/// Owns the connection to the agent plus the tasks driving it; dropping the server drops
/// those tasks.
pub struct AcpServer {
    /// Connection used to send requests (e.g. create thread, send message) to the agent.
    connection: Arc<acp::AgentConnection>,
    /// Threads created through this server, keyed by id. Held as weak handles so the
    /// registry does not keep threads alive on its own; shared with the client delegate.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// Project that threads created by this server operate on.
    project: Entity<Project>,
    /// Foreground task handling requests coming from the agent; held only to keep it alive.
    _handler_task: Task<()>,
    /// Background task pumping connection I/O and then waiting for the agent process to exit.
    _io_task: Task<()>,
}
 20
 /// Serves requests sent by the agent: answers file queries from the project and routes
/// thread updates to the registered [`AcpThread`] entities.
struct AcpClientDelegate {
    /// Project used to resolve absolute paths and open files on the agent's behalf.
    project: Entity<Project>,
    /// Registry of threads, shared with the owning [`AcpServer`]; entries are weak handles.
    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
    /// App context; each request handler works on its own clone of it.
    cx: AsyncApp,
    // NOTE(review): planned but not yet implemented — presumably for tracking which buffer
    // versions were sent to the agent:
    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
}
 27
 28impl AcpClientDelegate {
 29    fn new(
 30        project: Entity<Project>,
 31        threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
 32        cx: AsyncApp,
 33    ) -> Self {
 34        Self {
 35            project,
 36            threads,
 37            cx: cx,
 38        }
 39    }
 40
 41    fn update_thread<R>(
 42        &self,
 43        thread_id: &ThreadId,
 44        cx: &mut App,
 45        callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
 46    ) -> Option<R> {
 47        let thread = self.threads.lock().get(&thread_id)?.clone();
 48        let Some(thread) = thread.upgrade() else {
 49            self.threads.lock().remove(&thread_id);
 50            return None;
 51        };
 52        Some(thread.update(cx, callback))
 53    }
 54}
 55
 56#[async_trait(?Send)]
 57impl acp::Client for AcpClientDelegate {
 58    async fn stat(&self, params: acp::StatParams) -> Result<acp::StatResponse> {
 59        let cx = &mut self.cx.clone();
 60        self.project.update(cx, |project, cx| {
 61            let path = project
 62                .project_path_for_absolute_path(Path::new(&params.path), cx)
 63                .context("Failed to get project path")?;
 64
 65            match project.entry_for_path(&path, cx) {
 66                // todo! refresh entry?
 67                None => Ok(acp::StatResponse {
 68                    exists: false,
 69                    is_directory: false,
 70                }),
 71                Some(entry) => Ok(acp::StatResponse {
 72                    exists: entry.is_created(),
 73                    is_directory: entry.is_dir(),
 74                }),
 75            }
 76        })?
 77    }
 78
 79    async fn stream_message_chunk(
 80        &self,
 81        params: acp::StreamMessageChunkParams,
 82    ) -> Result<acp::StreamMessageChunkResponse> {
 83        let cx = &mut self.cx.clone();
 84
 85        cx.update(|cx| {
 86            self.update_thread(&params.thread_id.into(), cx, |thread, cx| {
 87                thread.push_assistant_chunk(params.chunk, cx)
 88            });
 89        })?;
 90
 91        Ok(acp::StreamMessageChunkResponse)
 92    }
 93
 94    async fn read_text_file(
 95        &self,
 96        request: acp::ReadTextFileParams,
 97    ) -> Result<acp::ReadTextFileResponse> {
 98        let cx = &mut self.cx.clone();
 99        let buffer = self
100            .project
101            .update(cx, |project, cx| {
102                let path = project
103                    .project_path_for_absolute_path(Path::new(&request.path), cx)
104                    .context("Failed to get project path")?;
105                anyhow::Ok(project.open_buffer(path, cx))
106            })??
107            .await?;
108
109        buffer.update(cx, |buffer, _cx| {
110            let start = language::Point::new(request.line_offset.unwrap_or(0), 0);
111            let end = match request.line_limit {
112                None => buffer.max_point(),
113                Some(limit) => start + language::Point::new(limit + 1, 0),
114            };
115
116            let content: String = buffer.text_for_range(start..end).collect();
117
118            acp::ReadTextFileResponse {
119                content,
120                version: acp::FileVersion(0),
121            }
122        })
123    }
124
125    async fn read_binary_file(
126        &self,
127        request: acp::ReadBinaryFileParams,
128    ) -> Result<acp::ReadBinaryFileResponse> {
129        let cx = &mut self.cx.clone();
130        let file = self
131            .project
132            .update(cx, |project, cx| {
133                let (worktree, path) = project
134                    .find_worktree(Path::new(&request.path), cx)
135                    .context("Failed to get project path")?;
136
137                let task = worktree.update(cx, |worktree, cx| worktree.load_binary_file(&path, cx));
138                anyhow::Ok(task)
139            })??
140            .await?;
141
142        // todo! test
143        let content = cx
144            .background_spawn(async move {
145                let start = request.byte_offset.unwrap_or(0) as usize;
146                let end = request
147                    .byte_limit
148                    .map(|limit| (start + limit as usize).min(file.content.len()))
149                    .unwrap_or(file.content.len());
150
151                let range_content = &file.content[start..end];
152
153                let mut base64_content = Vec::new();
154                let mut base64_encoder = base64::write::EncoderWriter::new(
155                    std::io::Cursor::new(&mut base64_content),
156                    &base64::engine::general_purpose::STANDARD,
157                );
158                base64_encoder.write_all(range_content)?;
159                drop(base64_encoder);
160
161                // SAFETY: The base64 encoder should not produce non-UTF8.
162                unsafe { anyhow::Ok(String::from_utf8_unchecked(base64_content)) }
163            })
164            .await?;
165
166        Ok(acp::ReadBinaryFileResponse {
167            content,
168            // todo!
169            version: acp::FileVersion(0),
170        })
171    }
172
173    async fn glob_search(
174        &self,
175        _request: acp::GlobSearchParams,
176    ) -> Result<acp::GlobSearchResponse> {
177        todo!()
178    }
179
180    async fn request_tool_call_confirmation(
181        &self,
182        request: acp::RequestToolCallConfirmationParams,
183    ) -> Result<acp::RequestToolCallConfirmationResponse> {
184        let cx = &mut self.cx.clone();
185        let ToolCallRequest { id, outcome } = cx
186            .update(|cx| {
187                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
188                    thread.request_tool_call(request.display_name, request.confirmation, cx)
189                })
190            })?
191            .context("Failed to update thread")?;
192
193        Ok(acp::RequestToolCallConfirmationResponse {
194            id: id.into(),
195            outcome: outcome.await?,
196        })
197    }
198
199    async fn push_tool_call(
200        &self,
201        request: acp::PushToolCallParams,
202    ) -> Result<acp::PushToolCallResponse> {
203        let cx = &mut self.cx.clone();
204        let entry_id = cx
205            .update(|cx| {
206                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
207                    thread.push_tool_call(request.display_name, cx)
208                })
209            })?
210            .context("Failed to update thread")?;
211
212        Ok(acp::PushToolCallResponse {
213            id: entry_id.into(),
214        })
215    }
216
217    async fn update_tool_call(
218        &self,
219        request: acp::UpdateToolCallParams,
220    ) -> Result<acp::UpdateToolCallResponse> {
221        let cx = &mut self.cx.clone();
222
223        cx.update(|cx| {
224            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
225                thread.update_tool_call(
226                    request.tool_call_id.into(),
227                    request.status,
228                    request.content,
229                    cx,
230                )
231            })
232        })?
233        .context("Failed to update thread")??;
234
235        Ok(acp::UpdateToolCallResponse)
236    }
237}
238
239impl AcpServer {
240    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut AsyncApp) -> Arc<Self> {
241        let stdin = process.stdin.take().expect("process didn't have stdin");
242        let stdout = process.stdout.take().expect("process didn't have stdout");
243
244        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
245        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
246            AcpClientDelegate::new(project.clone(), threads.clone(), cx.clone()),
247            stdin,
248            stdout,
249        );
250
251        let io_task = cx.background_spawn(async move {
252            io_fut.await.log_err();
253            process.status().await.log_err();
254        });
255
256        Arc::new(Self {
257            project,
258            connection: Arc::new(connection),
259            threads,
260            _handler_task: cx.foreground_executor().spawn(handler_fut),
261            _io_task: io_task,
262        })
263    }
264}
265
266impl AcpServer {
267    pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
268        let response = self.connection.request(acp::CreateThreadParams).await?;
269        let thread_id: ThreadId = response.thread_id.into();
270        let server = self.clone();
271        let thread = cx.new(|_| AcpThread {
272            // todo!
273            title: "ACP Thread".into(),
274            id: thread_id.clone(),
275            next_entry_id: ThreadEntryId(0),
276            entries: Vec::default(),
277            project: self.project.clone(),
278            server,
279        })?;
280        self.threads.lock().insert(thread_id, thread.downgrade());
281        Ok(thread)
282    }
283
284    pub async fn send_message(
285        &self,
286        thread_id: ThreadId,
287        message: acp::Message,
288        _cx: &mut AsyncApp,
289    ) -> Result<()> {
290        self.connection
291            .request(acp::SendMessageParams {
292                thread_id: thread_id.clone().into(),
293                message,
294            })
295            .await?;
296        Ok(())
297    }
298}
299
300impl From<acp::ThreadId> for ThreadId {
301    fn from(thread_id: acp::ThreadId) -> Self {
302        Self(thread_id.0.into())
303    }
304}
305
306impl From<ThreadId> for acp::ThreadId {
307    fn from(thread_id: ThreadId) -> Self {
308        acp::ThreadId(thread_id.0.to_string())
309    }
310}
311
312impl From<acp::ToolCallId> for ToolCallId {
313    fn from(tool_call_id: acp::ToolCallId) -> Self {
314        Self(ThreadEntryId(tool_call_id.0))
315    }
316}
317
318impl From<ToolCallId> for acp::ToolCallId {
319    fn from(tool_call_id: ToolCallId) -> Self {
320        acp::ToolCallId(tool_call_id.as_u64())
321    }
322}