server.rs

  1use crate::{AcpThread, ThreadEntryId, ThreadId, ToolCallId, ToolCallRequest};
  2use agentic_coding_protocol as acp;
  3use anyhow::{Context as _, Result};
  4use async_trait::async_trait;
  5use collections::HashMap;
  6use gpui::{App, AppContext, AsyncApp, Context, Entity, Task, WeakEntity};
  7use parking_lot::Mutex;
  8use project::Project;
  9use smol::process::Child;
 10use std::{process::ExitStatus, sync::Arc};
 11use util::ResultExt;
 12
 13pub struct AcpServer {
 14    connection: Arc<acp::AgentConnection>,
 15    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
 16    project: Entity<Project>,
 17    exit_status: Arc<Mutex<Option<ExitStatus>>>,
 18    _handler_task: Task<()>,
 19    _io_task: Task<()>,
 20}
 21
 22struct AcpClientDelegate {
 23    threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>,
 24    cx: AsyncApp,
 25    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
 26}
 27
 28impl AcpClientDelegate {
 29    fn new(threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>>, cx: AsyncApp) -> Self {
 30        Self { threads, cx: cx }
 31    }
 32
 33    fn update_thread<R>(
 34        &self,
 35        thread_id: &ThreadId,
 36        cx: &mut App,
 37        callback: impl FnOnce(&mut AcpThread, &mut Context<AcpThread>) -> R,
 38    ) -> Option<R> {
 39        let thread = self.threads.lock().get(&thread_id)?.clone();
 40        let Some(thread) = thread.upgrade() else {
 41            self.threads.lock().remove(&thread_id);
 42            return None;
 43        };
 44        Some(thread.update(cx, callback))
 45    }
 46}
 47
 48#[async_trait(?Send)]
 49impl acp::Client for AcpClientDelegate {
 50    async fn stream_assistant_message_chunk(
 51        &self,
 52        params: acp::StreamAssistantMessageChunkParams,
 53    ) -> Result<acp::StreamAssistantMessageChunkResponse> {
 54        let cx = &mut self.cx.clone();
 55
 56        cx.update(|cx| {
 57            self.update_thread(&params.thread_id.into(), cx, |thread, cx| {
 58                thread.push_assistant_chunk(params.chunk, cx)
 59            });
 60        })?;
 61
 62        Ok(acp::StreamAssistantMessageChunkResponse)
 63    }
 64
 65    async fn request_tool_call_confirmation(
 66        &self,
 67        request: acp::RequestToolCallConfirmationParams,
 68    ) -> Result<acp::RequestToolCallConfirmationResponse> {
 69        let cx = &mut self.cx.clone();
 70        let ToolCallRequest { id, outcome } = cx
 71            .update(|cx| {
 72                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
 73                    thread.request_tool_call(
 74                        request.label,
 75                        request.icon,
 76                        request.content,
 77                        request.confirmation,
 78                        cx,
 79                    )
 80                })
 81            })?
 82            .context("Failed to update thread")?;
 83
 84        Ok(acp::RequestToolCallConfirmationResponse {
 85            id: id.into(),
 86            outcome: outcome.await?,
 87        })
 88    }
 89
 90    async fn push_tool_call(
 91        &self,
 92        request: acp::PushToolCallParams,
 93    ) -> Result<acp::PushToolCallResponse> {
 94        let cx = &mut self.cx.clone();
 95        let entry_id = cx
 96            .update(|cx| {
 97                self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
 98                    thread.push_tool_call(request.label, request.icon, request.content, cx)
 99                })
100            })?
101            .context("Failed to update thread")?;
102
103        Ok(acp::PushToolCallResponse {
104            id: entry_id.into(),
105        })
106    }
107
108    async fn update_tool_call(
109        &self,
110        request: acp::UpdateToolCallParams,
111    ) -> Result<acp::UpdateToolCallResponse> {
112        let cx = &mut self.cx.clone();
113
114        cx.update(|cx| {
115            self.update_thread(&request.thread_id.into(), cx, |thread, cx| {
116                thread.update_tool_call(
117                    request.tool_call_id.into(),
118                    request.status,
119                    request.content,
120                    cx,
121                )
122            })
123        })?
124        .context("Failed to update thread")??;
125
126        Ok(acp::UpdateToolCallResponse)
127    }
128}
129
130impl AcpServer {
131    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut App) -> Arc<Self> {
132        let stdin = process.stdin.take().expect("process didn't have stdin");
133        let stdout = process.stdout.take().expect("process didn't have stdout");
134
135        let threads: Arc<Mutex<HashMap<ThreadId, WeakEntity<AcpThread>>>> = Default::default();
136        let (connection, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
137            AcpClientDelegate::new(threads.clone(), cx.to_async()),
138            stdin,
139            stdout,
140        );
141
142        let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
143        let io_task = cx.background_spawn({
144            let exit_status = exit_status.clone();
145            async move {
146                io_fut.await.log_err();
147                let result = process.status().await.log_err();
148                *exit_status.lock() = result;
149            }
150        });
151
152        Arc::new(Self {
153            project,
154            connection: Arc::new(connection),
155            threads,
156            exit_status,
157            _handler_task: cx.foreground_executor().spawn(handler_fut),
158            _io_task: io_task,
159        })
160    }
161
162    pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
163        self.connection
164            .request(acp::InitializeParams)
165            .await
166            .map_err(to_anyhow)
167    }
168
169    pub async fn authenticate(&self) -> Result<()> {
170        self.connection
171            .request(acp::AuthenticateParams)
172            .await
173            .map_err(to_anyhow)?;
174
175        Ok(())
176    }
177
178    pub async fn create_thread(self: Arc<Self>, cx: &mut AsyncApp) -> Result<Entity<AcpThread>> {
179        let response = self
180            .connection
181            .request(acp::CreateThreadParams)
182            .await
183            .map_err(to_anyhow)?;
184
185        let thread_id: ThreadId = response.thread_id.into();
186        let server = self.clone();
187        let thread = cx.new(|cx| {
188            AcpThread::new(
189                server,
190                thread_id.clone(),
191                Vec::default(),
192                self.project.clone(),
193                cx,
194            )
195        })?;
196        self.threads.lock().insert(thread_id, thread.downgrade());
197        Ok(thread)
198    }
199
200    pub async fn send_message(
201        &self,
202        thread_id: ThreadId,
203        message: acp::UserMessage,
204        _cx: &mut AsyncApp,
205    ) -> Result<()> {
206        self.connection
207            .request(acp::SendUserMessageParams {
208                thread_id: thread_id.clone().into(),
209                message,
210            })
211            .await
212            .map_err(to_anyhow)?;
213        Ok(())
214    }
215
216    pub async fn cancel_send_message(&self, thread_id: ThreadId, _cx: &mut AsyncApp) -> Result<()> {
217        self.connection
218            .request(acp::CancelSendMessageParams {
219                thread_id: thread_id.clone().into(),
220            })
221            .await
222            .map_err(to_anyhow)?;
223        Ok(())
224    }
225
226    pub fn exit_status(&self) -> Option<ExitStatus> {
227        *self.exit_status.lock()
228    }
229}
230
231#[track_caller]
232fn to_anyhow(e: acp::Error) -> anyhow::Error {
233    log::error!(
234        "failed to send message: {code}: {message}",
235        code = e.code,
236        message = e.message
237    );
238    anyhow::anyhow!(e.message)
239}
240
241impl From<acp::ThreadId> for ThreadId {
242    fn from(thread_id: acp::ThreadId) -> Self {
243        Self(thread_id.0.into())
244    }
245}
246
247impl From<ThreadId> for acp::ThreadId {
248    fn from(thread_id: ThreadId) -> Self {
249        acp::ThreadId(thread_id.0.to_string())
250    }
251}
252
253impl From<acp::ToolCallId> for ToolCallId {
254    fn from(tool_call_id: acp::ToolCallId) -> Self {
255        Self(ThreadEntryId(tool_call_id.0))
256    }
257}
258
259impl From<ToolCallId> for acp::ToolCallId {
260    fn from(tool_call_id: ToolCallId) -> Self {
261        acp::ToolCallId(tool_call_id.as_u64())
262    }
263}