acp_connection.rs

  1use agent_client_protocol as acp;
  2use anyhow::anyhow;
  3use collections::HashMap;
  4use context_server::listener::McpServerTool;
  5use context_server::types::requests;
  6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  7use futures::channel::{mpsc, oneshot};
  8use project::Project;
  9use smol::stream::StreamExt as _;
 10use std::cell::RefCell;
 11use std::rc::Rc;
 12use std::{path::Path, sync::Arc};
 13use util::ResultExt;
 14
 15use anyhow::{Context, Result};
 16use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 17
 18use crate::mcp_server::ZedMcpServer;
 19use crate::{AgentServerCommand, mcp_server};
 20use acp_thread::{AcpThread, AgentConnection, AuthRequired};
 21
 22pub struct AcpConnection {
 23    auth_methods: Rc<RefCell<Vec<acp::AuthMethod>>>,
 24    server_name: &'static str,
 25    context_server: Arc<context_server::ContextServer>,
 26    sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
 27    _session_update_task: Task<()>,
 28}
 29
 30impl AcpConnection {
 31    pub async fn stdio(
 32        server_name: &'static str,
 33        command: AgentServerCommand,
 34        working_directory: Option<Arc<Path>>,
 35        cx: &mut AsyncApp,
 36    ) -> Result<Self> {
 37        let context_server: Arc<ContextServer> = ContextServer::stdio(
 38            ContextServerId(format!("{}-mcp-server", server_name).into()),
 39            ContextServerCommand {
 40                path: command.path,
 41                args: command.args,
 42                env: command.env,
 43            },
 44            working_directory,
 45        )
 46        .into();
 47
 48        let (notification_tx, mut notification_rx) = mpsc::unbounded();
 49
 50        let sessions = Rc::new(RefCell::new(HashMap::default()));
 51
 52        let session_update_handler_task = cx.spawn({
 53            let sessions = sessions.clone();
 54            async move |cx| {
 55                while let Some(notification) = notification_rx.next().await {
 56                    Self::handle_session_notification(notification, sessions.clone(), cx)
 57                }
 58            }
 59        });
 60
 61        context_server
 62            .start_with_handlers(
 63                vec![(acp::AGENT_METHODS.session_update, {
 64                    Box::new(move |notification, _cx| {
 65                        let notification_tx = notification_tx.clone();
 66                        log::trace!(
 67                            "ACP Notification: {}",
 68                            serde_json::to_string_pretty(&notification).unwrap()
 69                        );
 70
 71                        if let Some(notification) =
 72                            serde_json::from_value::<acp::SessionNotification>(notification)
 73                                .log_err()
 74                        {
 75                            notification_tx.unbounded_send(notification).ok();
 76                        }
 77                    })
 78                })],
 79                cx,
 80            )
 81            .await?;
 82
 83        Ok(Self {
 84            auth_methods: Default::default(),
 85            server_name,
 86            context_server,
 87            sessions,
 88            _session_update_task: session_update_handler_task,
 89        })
 90    }
 91
 92    pub fn handle_session_notification(
 93        notification: acp::SessionNotification,
 94        threads: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
 95        cx: &mut AsyncApp,
 96    ) {
 97        let threads = threads.borrow();
 98        let Some(thread) = threads
 99            .get(&notification.session_id)
100            .and_then(|session| session.thread.upgrade())
101        else {
102            log::error!(
103                "Thread not found for session ID: {}",
104                notification.session_id
105            );
106            return;
107        };
108
109        thread
110            .update(cx, |thread, cx| {
111                thread.handle_session_update(notification.update, cx)
112            })
113            .log_err();
114    }
115}
116
117pub struct AcpSession {
118    thread: WeakEntity<AcpThread>,
119    cancel_tx: Option<oneshot::Sender<()>>,
120    _mcp_server: ZedMcpServer,
121}
122
123impl AgentConnection for AcpConnection {
124    fn new_thread(
125        self: Rc<Self>,
126        project: Entity<Project>,
127        cwd: &Path,
128        cx: &mut AsyncApp,
129    ) -> Task<Result<Entity<AcpThread>>> {
130        let client = self.context_server.client();
131        let sessions = self.sessions.clone();
132        let auth_methods = self.auth_methods.clone();
133        let cwd = cwd.to_path_buf();
134        cx.spawn(async move |cx| {
135            let client = client.context("MCP server is not initialized yet")?;
136            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
137
138            let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
139
140            let response = client
141                .request::<requests::CallTool>(context_server::types::CallToolParams {
142                    name: acp::AGENT_METHODS.new_session.into(),
143                    arguments: Some(serde_json::to_value(acp::NewSessionArguments {
144                        mcp_servers: vec![mcp_server.server_config()?],
145                        client_tools: acp::ClientTools {
146                            request_permission: Some(acp::McpToolId {
147                                mcp_server: mcp_server::SERVER_NAME.into(),
148                                tool_name: mcp_server::RequestPermissionTool::NAME.into(),
149                            }),
150                            read_text_file: Some(acp::McpToolId {
151                                mcp_server: mcp_server::SERVER_NAME.into(),
152                                tool_name: mcp_server::ReadTextFileTool::NAME.into(),
153                            }),
154                            write_text_file: Some(acp::McpToolId {
155                                mcp_server: mcp_server::SERVER_NAME.into(),
156                                tool_name: mcp_server::WriteTextFileTool::NAME.into(),
157                            }),
158                        },
159                        cwd,
160                    })?),
161                    meta: None,
162                })
163                .await?;
164
165            if response.is_error.unwrap_or_default() {
166                return Err(anyhow!(response.text_contents()));
167            }
168
169            let result = serde_json::from_value::<acp::NewSessionOutput>(
170                response.structured_content.context("Empty response")?,
171            )?;
172
173            auth_methods.replace(result.auth_methods);
174
175            let Some(session_id) = result.session_id else {
176                anyhow::bail!(AuthRequired);
177            };
178
179            let thread = cx.new(|cx| {
180                AcpThread::new(
181                    self.server_name,
182                    self.clone(),
183                    project,
184                    session_id.clone(),
185                    cx,
186                )
187            })?;
188
189            thread_tx.send(thread.downgrade())?;
190
191            let session = AcpSession {
192                thread: thread.downgrade(),
193                cancel_tx: None,
194                _mcp_server: mcp_server,
195            };
196            sessions.borrow_mut().insert(session_id, session);
197
198            Ok(thread)
199        })
200    }
201
202    fn auth_methods(&self) -> Vec<agent_client_protocol::AuthMethod> {
203        self.auth_methods.borrow().clone()
204    }
205
206    fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
207        let client = self.context_server.client();
208        cx.foreground_executor().spawn(async move {
209            let params = acp::AuthenticateArguments { method_id };
210
211            let response = client
212                .context("MCP server is not initialized yet")?
213                .request::<requests::CallTool>(context_server::types::CallToolParams {
214                    name: acp::AGENT_METHODS.authenticate.into(),
215                    arguments: Some(serde_json::to_value(params)?),
216                    meta: None,
217                })
218                .await?;
219
220            if response.is_error.unwrap_or_default() {
221                Err(anyhow!(response.text_contents()))
222            } else {
223                Ok(())
224            }
225        })
226    }
227
228    fn prompt(
229        &self,
230        params: agent_client_protocol::PromptArguments,
231        cx: &mut App,
232    ) -> Task<Result<()>> {
233        let client = self.context_server.client();
234        let sessions = self.sessions.clone();
235
236        cx.foreground_executor().spawn(async move {
237            let client = client.context("MCP server is not initialized yet")?;
238
239            let (new_cancel_tx, cancel_rx) = oneshot::channel();
240            {
241                let mut sessions = sessions.borrow_mut();
242                let session = sessions
243                    .get_mut(&params.session_id)
244                    .context("Session not found")?;
245                session.cancel_tx.replace(new_cancel_tx);
246            }
247
248            let result = client
249                .request_with::<requests::CallTool>(
250                    context_server::types::CallToolParams {
251                        name: acp::AGENT_METHODS.prompt.into(),
252                        arguments: Some(serde_json::to_value(params)?),
253                        meta: None,
254                    },
255                    Some(cancel_rx),
256                    None,
257                )
258                .await;
259
260            if let Err(err) = &result
261                && err.is::<context_server::client::RequestCanceled>()
262            {
263                return Ok(());
264            }
265
266            let response = result?;
267
268            if response.is_error.unwrap_or_default() {
269                return Err(anyhow!(response.text_contents()));
270            }
271
272            Ok(())
273        })
274    }
275
276    fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
277        let mut sessions = self.sessions.borrow_mut();
278
279        if let Some(cancel_tx) = sessions
280            .get_mut(session_id)
281            .and_then(|session| session.cancel_tx.take())
282        {
283            cancel_tx.send(()).ok();
284        }
285    }
286}
287
288impl Drop for AcpConnection {
289    fn drop(&mut self) {
290        self.context_server.stop().log_err();
291    }
292}