@@ -22,7 +22,7 @@ use acp_thread::{AcpThread, AgentConnection};
pub struct AcpConnection {
agent_state: Rc<RefCell<acp::AgentState>>,
server_name: &'static str,
- client: Arc<context_server::ContextServer>,
+ context_server: Arc<context_server::ContextServer>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
_agent_state_task: Task<()>,
_session_update_task: Task<()>,
@@ -35,7 +35,7 @@ impl AcpConnection {
working_directory: Option<Arc<Path>>,
cx: &mut AsyncApp,
) -> Result<Self> {
- let client: Arc<ContextServer> = ContextServer::stdio(
+ let context_server: Arc<ContextServer> = ContextServer::stdio(
ContextServerId(format!("{}-mcp-server", server_name).into()),
ContextServerCommand {
path: command.path,
@@ -45,42 +45,9 @@ impl AcpConnection {
working_directory,
)
.into();
- ContextServer::start(client.clone(), cx).await?;
let (mut state_tx, mut state_rx) = watch::channel(acp::AgentState::default());
- let mcp_client = client.client().context("Failed to subscribe")?;
-
- mcp_client.on_notification(acp::AGENT_METHODS.agent_state, {
- move |notification, _cx| {
- log::trace!(
- "ACP Notification: {}",
- serde_json::to_string_pretty(¬ification).unwrap()
- );
-
- if let Some(state) =
- serde_json::from_value::<acp::AgentState>(notification).log_err()
- {
- state_tx.send(state).log_err();
- }
- }
- });
-
let (notification_tx, mut notification_rx) = mpsc::unbounded();
- mcp_client.on_notification(acp::AGENT_METHODS.session_update, {
- move |notification, _cx| {
- let notification_tx = notification_tx.clone();
- log::trace!(
- "ACP Notification: {}",
- serde_json::to_string_pretty(¬ification).unwrap()
- );
-
- if let Some(notification) =
- serde_json::from_value::<acp::SessionNotification>(notification).log_err()
- {
- notification_tx.unbounded_send(notification).ok();
- }
- }
- });
let sessions = Rc::new(RefCell::new(HashMap::default()));
let initial_state = state_rx.recv().await?;
@@ -104,9 +71,47 @@ impl AcpConnection {
}
});
+ context_server
+ .start_with_handlers(
+ vec![
+ (acp::AGENT_METHODS.agent_state, {
+ Box::new(move |notification, _cx| {
+ log::trace!(
+ "ACP Notification: {}",
+ serde_json::to_string_pretty(¬ification).unwrap()
+ );
+
+ if let Some(state) =
+ serde_json::from_value::<acp::AgentState>(notification).log_err()
+ {
+ state_tx.send(state).log_err();
+ }
+ })
+ }),
+ (acp::AGENT_METHODS.session_update, {
+ Box::new(move |notification, _cx| {
+ let notification_tx = notification_tx.clone();
+ log::trace!(
+ "ACP Notification: {}",
+ serde_json::to_string_pretty(¬ification).unwrap()
+ );
+
+ if let Some(notification) =
+ serde_json::from_value::<acp::SessionNotification>(notification)
+ .log_err()
+ {
+ notification_tx.unbounded_send(notification).ok();
+ }
+ })
+ }),
+ ],
+ cx,
+ )
+ .await?;
+
Ok(Self {
server_name,
- client,
+ context_server,
sessions,
agent_state,
_agent_state_task: agent_state_task,
@@ -152,7 +157,7 @@ impl AgentConnection for AcpConnection {
cwd: &Path,
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>> {
- let client = self.client.client();
+ let client = self.context_server.client();
let sessions = self.sessions.clone();
let cwd = cwd.to_path_buf();
cx.spawn(async move |cx| {
@@ -222,7 +227,7 @@ impl AgentConnection for AcpConnection {
}
fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
- let client = self.client.client();
+ let client = self.context_server.client();
cx.foreground_executor().spawn(async move {
let params = acp::AuthenticateArguments { method_id };
@@ -248,7 +253,7 @@ impl AgentConnection for AcpConnection {
params: agent_client_protocol::PromptArguments,
cx: &mut App,
) -> Task<Result<()>> {
- let client = self.client.client();
+ let client = self.context_server.client();
let sessions = self.sessions.clone();
cx.foreground_executor().spawn(async move {
@@ -305,6 +310,6 @@ impl AgentConnection for AcpConnection {
impl Drop for AcpConnection {
fn drop(&mut self) {
- self.client.stop().log_err();
+ self.context_server.stop().log_err();
}
}
@@ -441,14 +441,12 @@ impl Client {
Ok(())
}
- #[allow(unused)]
- pub fn on_notification<F>(&self, method: &'static str, f: F)
- where
- F: 'static + Send + FnMut(Value, AsyncApp),
- {
- self.notification_handlers
- .lock()
- .insert(method, Box::new(f));
+ pub fn on_notification(
+ &self,
+ method: &'static str,
+ f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
+ ) {
+ self.notification_handlers.lock().insert(method, f);
}
}
@@ -95,8 +95,28 @@ impl ContextServer {
self.client.read().clone()
}
- pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
- let client = match &self.configuration {
+ pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
+ self.initialize(self.new_client(cx)?).await
+ }
+
+ /// Starts the context server, making sure handlers are registered before initialization happens
+ pub async fn start_with_handlers(
+ &self,
+ notification_handlers: Vec<(
+ &'static str,
+ Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
+ )>,
+ cx: &AsyncApp,
+ ) -> Result<()> {
+ let client = self.new_client(cx)?;
+ for (method, handler) in notification_handlers {
+ client.on_notification(method, handler);
+ }
+ self.initialize(client).await
+ }
+
+ fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
+ Ok(match &self.configuration {
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
client::ContextServerId(self.id.0.clone()),
client::ModelContextServerBinary {
@@ -113,8 +133,7 @@ impl ContextServer {
transport.clone(),
cx.clone(),
)?,
- };
- self.initialize(client).await
+ })
}
async fn initialize(&self, client: Client) -> Result<()> {
@@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
self.inner.notify(T::METHOD, params)
}
- pub fn on_notification<F>(&self, method: &'static str, f: F)
- where
- F: 'static + Send + FnMut(Value, AsyncApp),
- {
+ pub fn on_notification(
+ &self,
+ method: &'static str,
+ f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
+ ) {
self.inner.on_notification(method, f);
}
}