acp_connection.rs

  1use agent_client_protocol as acp;
  2use anyhow::anyhow;
  3use collections::HashMap;
  4use context_server::listener::McpServerTool;
  5use context_server::types::requests;
  6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  7use futures::channel::{mpsc, oneshot};
  8use project::Project;
  9use smol::stream::StreamExt as _;
 10use std::cell::RefCell;
 11use std::rc::Rc;
 12use std::{path::Path, sync::Arc};
 13use util::ResultExt;
 14
 15use anyhow::{Context, Result};
 16use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 17
 18use crate::mcp_server::ZedMcpServer;
 19use crate::{AgentServerCommand, mcp_server};
 20use acp_thread::{AcpThread, AgentConnection};
 21
/// An [`AgentConnection`] that speaks the Agent Client Protocol to an agent process,
/// issuing ACP methods as MCP `CallTool` requests over a [`ContextServer`] client.
pub struct AcpConnection {
    /// Static name of this agent server; used to derive the context-server id and
    /// passed to every [`AcpThread`] this connection creates.
    server_name: &'static str,
    /// Context-server handle used for all ACP requests; stopped when the connection drops.
    client: Arc<context_server::ContextServer>,
    /// Live sessions keyed by ACP session id. Shared (via `Rc`) with the notification task.
    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
    /// Task that routes incoming `session_update` notifications to their threads; stored
    /// only to tie its lifetime to the connection (presumably dropping a gpui `Task`
    /// cancels it — confirm).
    _notification_handler_task: Task<()>,
}
 28
 29impl AcpConnection {
 30    pub async fn stdio(
 31        server_name: &'static str,
 32        command: AgentServerCommand,
 33        cx: &mut AsyncApp,
 34    ) -> Result<Self> {
 35        let client: Arc<ContextServer> = ContextServer::stdio(
 36            ContextServerId(format!("{}-mcp-server", server_name).into()),
 37            ContextServerCommand {
 38                path: command.path,
 39                args: command.args,
 40                env: command.env,
 41            },
 42        )
 43        .into();
 44        ContextServer::start(client.clone(), cx).await?;
 45
 46        let (notification_tx, mut notification_rx) = mpsc::unbounded();
 47        client
 48            .client()
 49            .context("Failed to subscribe")?
 50            .on_notification(acp::AGENT_METHODS.session_update, {
 51                move |notification, _cx| {
 52                    let notification_tx = notification_tx.clone();
 53                    log::trace!(
 54                        "ACP Notification: {}",
 55                        serde_json::to_string_pretty(&notification).unwrap()
 56                    );
 57
 58                    if let Some(notification) =
 59                        serde_json::from_value::<acp::SessionNotification>(notification).log_err()
 60                    {
 61                        notification_tx.unbounded_send(notification).ok();
 62                    }
 63                }
 64            });
 65
 66        let sessions = Rc::new(RefCell::new(HashMap::default()));
 67
 68        let notification_handler_task = cx.spawn({
 69            let sessions = sessions.clone();
 70            async move |cx| {
 71                while let Some(notification) = notification_rx.next().await {
 72                    Self::handle_session_notification(notification, sessions.clone(), cx)
 73                }
 74            }
 75        });
 76
 77        Ok(Self {
 78            server_name,
 79            client,
 80            sessions,
 81            _notification_handler_task: notification_handler_task,
 82        })
 83    }
 84
 85    pub fn handle_session_notification(
 86        notification: acp::SessionNotification,
 87        threads: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
 88        cx: &mut AsyncApp,
 89    ) {
 90        let threads = threads.borrow();
 91        let Some(thread) = threads
 92            .get(&notification.session_id)
 93            .and_then(|session| session.thread.upgrade())
 94        else {
 95            log::error!(
 96                "Thread not found for session ID: {}",
 97                notification.session_id
 98            );
 99            return;
100        };
101
102        thread
103            .update(cx, |thread, cx| {
104                thread.handle_session_update(notification.update, cx)
105            })
106            .log_err();
107    }
108}
109
/// Per-session state tracked by [`AcpConnection`].
pub struct AcpSession {
    /// Thread that receives this session's updates. Held weakly and upgraded on each
    /// notification, so a dropped thread is detected rather than kept alive here.
    thread: WeakEntity<AcpThread>,
    /// Sender used to cancel the in-flight `prompt` request, if one is registered.
    /// Replaced by each `prompt` call and taken by `cancel`.
    cancel_tx: Option<oneshot::Sender<()>>,
    /// Zed-side MCP server advertised to the agent for this session; owned here so it
    /// is dropped together with the session entry.
    _mcp_server: ZedMcpServer,
}
115
116impl AgentConnection for AcpConnection {
117    fn new_thread(
118        self: Rc<Self>,
119        project: Entity<Project>,
120        cwd: &Path,
121        cx: &mut AsyncApp,
122    ) -> Task<Result<Entity<AcpThread>>> {
123        let client = self.client.client();
124        let sessions = self.sessions.clone();
125        let cwd = cwd.to_path_buf();
126        cx.spawn(async move |cx| {
127            let client = client.context("MCP server is not initialized yet")?;
128            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
129
130            let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
131
132            let response = client
133                .request::<requests::CallTool>(context_server::types::CallToolParams {
134                    name: acp::AGENT_METHODS.new_session.into(),
135                    arguments: Some(serde_json::to_value(acp::NewSessionArguments {
136                        mcp_servers: vec![mcp_server.server_config()?],
137                        client_tools: acp::ClientTools {
138                            request_permission: Some(acp::McpToolId {
139                                mcp_server: mcp_server::SERVER_NAME.into(),
140                                tool_name: mcp_server::RequestPermissionTool::NAME.into(),
141                            }),
142                            read_text_file: Some(acp::McpToolId {
143                                mcp_server: mcp_server::SERVER_NAME.into(),
144                                tool_name: mcp_server::ReadTextFileTool::NAME.into(),
145                            }),
146                            write_text_file: Some(acp::McpToolId {
147                                mcp_server: mcp_server::SERVER_NAME.into(),
148                                tool_name: mcp_server::WriteTextFileTool::NAME.into(),
149                            }),
150                        },
151                        cwd,
152                    })?),
153                    meta: None,
154                })
155                .await?;
156
157            if response.is_error.unwrap_or_default() {
158                return Err(anyhow!(response.text_contents()));
159            }
160
161            let result = serde_json::from_value::<acp::NewSessionOutput>(
162                response.structured_content.context("Empty response")?,
163            )?;
164
165            let thread = cx.new(|cx| {
166                AcpThread::new(
167                    self.server_name,
168                    self.clone(),
169                    project,
170                    result.session_id.clone(),
171                    cx,
172                )
173            })?;
174
175            thread_tx.send(thread.downgrade())?;
176
177            let session = AcpSession {
178                thread: thread.downgrade(),
179                cancel_tx: None,
180                _mcp_server: mcp_server,
181            };
182            sessions.borrow_mut().insert(result.session_id, session);
183
184            Ok(thread)
185        })
186    }
187
188    fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
189        Task::ready(Err(anyhow!("Authentication not supported")))
190    }
191
192    fn prompt(
193        &self,
194        params: agent_client_protocol::PromptArguments,
195        cx: &mut App,
196    ) -> Task<Result<()>> {
197        let client = self.client.client();
198        let sessions = self.sessions.clone();
199
200        cx.foreground_executor().spawn(async move {
201            let client = client.context("MCP server is not initialized yet")?;
202
203            let (new_cancel_tx, cancel_rx) = oneshot::channel();
204            {
205                let mut sessions = sessions.borrow_mut();
206                let session = sessions
207                    .get_mut(&params.session_id)
208                    .context("Session not found")?;
209                session.cancel_tx.replace(new_cancel_tx);
210            }
211
212            let result = client
213                .request_with::<requests::CallTool>(
214                    context_server::types::CallToolParams {
215                        name: acp::AGENT_METHODS.prompt.into(),
216                        arguments: Some(serde_json::to_value(params)?),
217                        meta: None,
218                    },
219                    Some(cancel_rx),
220                    None,
221                )
222                .await;
223
224            if let Err(err) = &result
225                && err.is::<context_server::client::RequestCanceled>()
226            {
227                return Ok(());
228            }
229
230            let response = result?;
231
232            if response.is_error.unwrap_or_default() {
233                return Err(anyhow!(response.text_contents()));
234            }
235
236            Ok(())
237        })
238    }
239
240    fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
241        let mut sessions = self.sessions.borrow_mut();
242
243        if let Some(cancel_tx) = sessions
244            .get_mut(session_id)
245            .and_then(|session| session.cancel_tx.take())
246        {
247            cancel_tx.send(()).ok();
248        }
249    }
250}
251
252impl Drop for AcpConnection {
253    fn drop(&mut self) {
254        self.client.stop().log_err();
255    }
256}