diff --git a/crates/agent_servers/src/acp_connection.rs b/crates/agent_servers/src/acp_connection.rs index c19a145196038744676a8922518101b71b6366cd..bfb4d8b40f83479b4d9ae37fbc3ac48be6f4bb87 100644 --- a/crates/agent_servers/src/acp_connection.rs +++ b/crates/agent_servers/src/acp_connection.rs @@ -22,7 +22,7 @@ use acp_thread::{AcpThread, AgentConnection, AuthRequired}; pub struct AcpConnection { auth_methods: Rc>>, server_name: &'static str, - client: Arc, + context_server: Arc, sessions: Rc>>, _session_update_task: Task<()>, } @@ -34,7 +34,7 @@ impl AcpConnection { working_directory: Option>, cx: &mut AsyncApp, ) -> Result { - let client: Arc = ContextServer::stdio( + let context_server: Arc = ContextServer::stdio( ContextServerId(format!("{}-mcp-server", server_name).into()), ContextServerCommand { path: command.path, @@ -44,26 +44,8 @@ impl AcpConnection { working_directory, ) .into(); - ContextServer::start(client.clone(), cx).await?; - - let mcp_client = client.client().context("Failed to subscribe")?; let (notification_tx, mut notification_rx) = mpsc::unbounded(); - mcp_client.on_notification(acp::AGENT_METHODS.session_update, { - move |notification, _cx| { - let notification_tx = notification_tx.clone(); - log::trace!( - "ACP Notification: {}", - serde_json::to_string_pretty(¬ification).unwrap() - ); - - if let Some(notification) = - serde_json::from_value::(notification).log_err() - { - notification_tx.unbounded_send(notification).ok(); - } - } - }); let sessions = Rc::new(RefCell::new(HashMap::default())); @@ -76,10 +58,32 @@ impl AcpConnection { } }); + context_server + .start_with_handlers( + vec![(acp::AGENT_METHODS.session_update, { + Box::new(move |notification, _cx| { + let notification_tx = notification_tx.clone(); + log::trace!( + "ACP Notification: {}", + serde_json::to_string_pretty(¬ification).unwrap() + ); + + if let Some(notification) = + serde_json::from_value::(notification) + .log_err() + { + notification_tx.unbounded_send(notification).ok(); + } + }) + })], + cx, + ) + .await?; + Ok(Self { auth_methods: Default::default(), server_name, - client, + context_server, sessions, _session_update_task: session_update_handler_task, }) @@ -123,7 +127,7 @@ impl AgentConnection for AcpConnection { cwd: &Path, cx: &mut AsyncApp, ) -> Task>> { - let client = self.client.client(); + let client = self.context_server.client(); let sessions = self.sessions.clone(); let auth_methods = self.auth_methods.clone(); let cwd = cwd.to_path_buf(); @@ -200,7 +204,7 @@ impl AgentConnection for AcpConnection { } fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task> { - let client = self.client.client(); + let client = self.context_server.client(); cx.foreground_executor().spawn(async move { let params = acp::AuthenticateArguments { method_id }; @@ -226,7 +230,7 @@ impl AgentConnection for AcpConnection { params: agent_client_protocol::PromptArguments, cx: &mut App, ) -> Task> { - let client = self.client.client(); + let client = self.context_server.client(); let sessions = self.sessions.clone(); cx.foreground_executor().spawn(async move { @@ -283,6 +287,6 @@ impl AgentConnection for AcpConnection { impl Drop for AcpConnection { fn drop(&mut self) { - self.client.stop().log_err(); + self.context_server.stop().log_err(); } } diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 1eb29bbbf9d61b6139e8d9a1d5fffd2836f55c8a..65283afa87d94fae3ec51f8a89574713080bded2 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -441,14 +441,12 @@ impl Client { Ok(()) } - #[allow(unused)] - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { - self.notification_handlers - .lock() - .insert(method, Box::new(f)); + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { + self.notification_handlers.lock().insert(method, f); } } diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index e76e7972f76a90743b0b34609f4407749660e50f..34fa29678d5d68f864de7d9df3bef82d4c667f05 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -95,8 +95,28 @@ impl ContextServer { self.client.read().clone() } - pub async fn start(self: Arc, cx: &AsyncApp) -> Result<()> { - let client = match &self.configuration { + pub async fn start(&self, cx: &AsyncApp) -> Result<()> { + self.initialize(self.new_client(cx)?).await + } + + /// Starts the context server, making sure handlers are registered before initialization happens + pub async fn start_with_handlers( + &self, + notification_handlers: Vec<( + &'static str, + Box, + )>, + cx: &AsyncApp, + ) -> Result<()> { + let client = self.new_client(cx)?; + for (method, handler) in notification_handlers { + client.on_notification(method, handler); + } + self.initialize(client).await + } + + fn new_client(&self, cx: &AsyncApp) -> Result { + Ok(match &self.configuration { ContextServerTransport::Stdio(command, working_directory) => Client::stdio( client::ContextServerId(self.id.0.clone()), client::ModelContextServerBinary { @@ -113,8 +133,7 @@ impl ContextServer { transport.clone(), cx.clone(), )?, - }; - self.initialize(client).await + }) } async fn initialize(&self, client: Client) -> Result<()> { diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 9ccbc8a55380c1c4c222894579f3b4f56d57468a..5355f20f620b5bed76bf945e863fdb5cbcc2ff43 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -115,10 +115,11 @@ impl InitializedContextServerProtocol { self.inner.notify(T::METHOD, params) } - pub fn on_notification(&self, method: &'static str, f: F) - where - F: 'static + Send + FnMut(Value, AsyncApp), - { + pub fn on_notification( + &self, + method: &'static str, + f: Box, + ) { self.inner.on_notification(method, f); } }