server.rs

  1use crate::{AcpThread, ThreadEntryId, ToolCallId, ToolCallRequest};
  2use agentic_coding_protocol as acp;
  3use anyhow::{Context as _, Result};
  4use async_trait::async_trait;
  5use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  6use parking_lot::Mutex;
  7use project::Project;
  8use smol::process::Child;
  9use std::{process::ExitStatus, sync::Arc};
 10use util::ResultExt;
 11
 12pub struct AcpServer {
 13    thread: WeakEntity<AcpThread>,
 14    project: Entity<Project>,
 15    connection: Arc<acp::AgentConnection>,
 16    exit_status: Arc<Mutex<Option<ExitStatus>>>,
 17    _handler_task: Task<()>,
 18    _io_task: Task<()>,
 19}
 20
 21struct AcpClientDelegate {
 22    thread: WeakEntity<AcpThread>,
 23    cx: AsyncApp,
 24    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
 25}
 26
 27impl AcpClientDelegate {
 28    fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
 29        Self { thread, cx }
 30    }
 31}
 32
 33#[async_trait(?Send)]
 34impl acp::Client for AcpClientDelegate {
 35    async fn stream_assistant_message_chunk(
 36        &self,
 37        params: acp::StreamAssistantMessageChunkParams,
 38    ) -> Result<acp::StreamAssistantMessageChunkResponse> {
 39        let cx = &mut self.cx.clone();
 40
 41        cx.update(|cx| {
 42            self.thread.update(cx, |thread, cx| {
 43                thread.push_assistant_chunk(params.chunk, cx)
 44            });
 45        })?;
 46
 47        Ok(acp::StreamAssistantMessageChunkResponse)
 48    }
 49
 50    async fn request_tool_call_confirmation(
 51        &self,
 52        request: acp::RequestToolCallConfirmationParams,
 53    ) -> Result<acp::RequestToolCallConfirmationResponse> {
 54        let cx = &mut self.cx.clone();
 55        let ToolCallRequest { id, outcome } = cx
 56            .update(|cx| {
 57                self.thread.update(cx, |thread, cx| {
 58                    thread.request_tool_call(
 59                        request.label,
 60                        request.icon,
 61                        request.content,
 62                        request.confirmation,
 63                        cx,
 64                    )
 65                })
 66            })?
 67            .context("Failed to update thread")?;
 68
 69        Ok(acp::RequestToolCallConfirmationResponse {
 70            id: id.into(),
 71            outcome: outcome.await?,
 72        })
 73    }
 74
 75    async fn push_tool_call(
 76        &self,
 77        request: acp::PushToolCallParams,
 78    ) -> Result<acp::PushToolCallResponse> {
 79        let cx = &mut self.cx.clone();
 80        let entry_id = cx
 81            .update(|cx| {
 82                self.thread.update(cx, |thread, cx| {
 83                    thread.push_tool_call(request.label, request.icon, request.content, cx)
 84                })
 85            })?
 86            .context("Failed to update thread")?;
 87
 88        Ok(acp::PushToolCallResponse {
 89            id: entry_id.into(),
 90        })
 91    }
 92
 93    async fn update_tool_call(
 94        &self,
 95        request: acp::UpdateToolCallParams,
 96    ) -> Result<acp::UpdateToolCallResponse> {
 97        let cx = &mut self.cx.clone();
 98
 99        cx.update(|cx| {
100            self.thread.update(cx, |thread, cx| {
101                thread.update_tool_call(
102                    request.tool_call_id.into(),
103                    request.status,
104                    request.content,
105                    cx,
106                )
107            })
108        })?
109        .context("Failed to update thread")??;
110
111        Ok(acp::UpdateToolCallResponse)
112    }
113}
114
115impl AcpServer {
116    pub fn stdio(mut process: Child, project: Entity<Project>, cx: &mut App) -> Arc<Self> {
117        let stdin = process.stdin.take().expect("process didn't have stdin");
118        let stdout = process.stdout.take().expect("process didn't have stdout");
119
120        let mut connection = None;
121        cx.new(|cx| {
122            let (conn, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
123                AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
124                stdin,
125                stdout,
126            );
127
128            let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
129            let io_task = cx.background_spawn({
130                let exit_status = exit_status.clone();
131                async move {
132                    io_fut.await.log_err();
133                    let result = process.status().await.log_err();
134                    *exit_status.lock() = result;
135                }
136            });
137
138            connection.replace(Arc::new(Self {
139                project: project.clone(),
140                connection: Arc::new(conn),
141                thread: cx.entity().downgrade(),
142                exit_status,
143                _handler_task: cx.foreground_executor().spawn(handler_fut),
144                _io_task: io_task,
145            }));
146
147            AcpThread::new(
148                connection.clone().unwrap(),
149                Vec::default(),
150                project.clone(),
151                cx,
152            )
153        });
154
155        connection.unwrap()
156    }
157
158    #[cfg(test)]
159    pub fn fake(
160        stdin: async_pipe::PipeWriter,
161        stdout: async_pipe::PipeReader,
162        project: Entity<Project>,
163        cx: &mut App,
164    ) -> Arc<Self> {
165        let mut connection = None;
166        cx.new(|cx| {
167            let (conn, handler_fut, io_fut) = acp::AgentConnection::connect_to_agent(
168                AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
169                stdin,
170                stdout,
171            );
172
173            let exit_status: Arc<Mutex<Option<ExitStatus>>> = Default::default();
174            let io_task = cx.background_spawn({
175                async move {
176                    io_fut.await.log_err();
177                    // todo!() exit status?
178                }
179            });
180
181            connection.replace(Arc::new(Self {
182                project: project.clone(),
183                connection: Arc::new(conn),
184                thread: cx.entity().downgrade(),
185                exit_status,
186                _handler_task: cx.foreground_executor().spawn(handler_fut),
187                _io_task: io_task,
188            }));
189
190            AcpThread::new(
191                connection.clone().unwrap(),
192                Vec::default(),
193                project.clone(),
194                cx,
195            )
196        });
197
198        connection.unwrap()
199    }
200
201    pub async fn initialize(&self) -> Result<acp::InitializeResponse> {
202        self.connection
203            .request(acp::InitializeParams)
204            .await
205            .map_err(to_anyhow)
206    }
207
208    pub async fn authenticate(&self) -> Result<()> {
209        self.connection
210            .request(acp::AuthenticateParams)
211            .await
212            .map_err(to_anyhow)?;
213
214        Ok(())
215    }
216
217    pub async fn send_message(&self, message: acp::UserMessage, _cx: &mut AsyncApp) -> Result<()> {
218        self.connection
219            .request(acp::SendUserMessageParams { message })
220            .await
221            .map_err(to_anyhow)?;
222        Ok(())
223    }
224
225    pub async fn cancel_send_message(&self, _cx: &mut AsyncApp) -> Result<()> {
226        self.connection
227            .request(acp::CancelSendMessageParams)
228            .await
229            .map_err(to_anyhow)?;
230        Ok(())
231    }
232
233    pub fn exit_status(&self) -> Option<ExitStatus> {
234        *self.exit_status.lock()
235    }
236}
237
238#[track_caller]
239fn to_anyhow(e: acp::Error) -> anyhow::Error {
240    log::error!(
241        "failed to send message: {code}: {message}",
242        code = e.code,
243        message = e.message
244    );
245    anyhow::anyhow!(e.message)
246}
247
248impl From<acp::ToolCallId> for ToolCallId {
249    fn from(tool_call_id: acp::ToolCallId) -> Self {
250        Self(ThreadEntryId(tool_call_id.0))
251    }
252}
253
254impl From<ToolCallId> for acp::ToolCallId {
255    fn from(tool_call_id: ToolCallId) -> Self {
256        acp::ToolCallId(tool_call_id.as_u64())
257    }
258}