diff --git a/Cargo.lock b/Cargo.lock index ea9e6fc49fdf0d0bf886ba61fd6e309e8e561241..772fb2492b6cc55cd8c4efb6ecf6c53217f96a2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20203,6 +20203,7 @@ dependencies = [ "diagnostics", "editor", "env_logger 0.11.8", + "erased-serde", "extension", "extension_host", "extensions_ui", diff --git a/crates/agent_servers/src/codex.rs b/crates/agent_servers/src/codex.rs index 95f9b039983ef62e5a1ca4eb26bfbba3f8dfb53f..3bc135aa89b0328b1f410b8bab6f9217a59b55a1 100644 --- a/crates/agent_servers/src/codex.rs +++ b/crates/agent_servers/src/codex.rs @@ -5,6 +5,7 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use futures::channel::{mpsc, oneshot}; use itertools::Itertools; use project::Project; +use serde::de::DeserializeOwned; use settings::SettingsStore; use smol::stream::StreamExt; use std::cell::RefCell; @@ -29,6 +30,57 @@ use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection}; #[derive(Clone)] pub struct Codex; +pub struct CodexApproval; +impl context_server::types::Request for CodexApproval { + type Params = CodexApprovalRequest; + type Response = CodexApprovalResponse; + const METHOD: &'static str = "elicitation/create"; +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CodexApprovalRequest { + // These fields are required so that `params` + // conforms to ElicitRequestParams. + pub message: String, + // #[serde(rename = "requestedSchema")] + // pub requested_schema: ElicitRequestParamsRequestedSchema, + + // // These are additional fields the client can use to + // // correlate the request with the codex tool call. + // pub codex_elicitation: String, + // pub codex_mcp_tool_call_id: String, + // pub codex_event_id: String, + // pub codex_command: Vec, + // pub codex_cwd: PathBuf, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CodexApprovalResponse { + pub decision: ReviewDecision, +} + +/// User's decision in response to an ExecApprovalRequest. +#[derive(Debug, Default, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ReviewDecision { + /// User has approved this command and the agent should execute it. + Approved, + + /// User has approved this command and wants to automatically approve any + /// future identical instances (`command` and `cwd` match exactly) for the + /// remainder of the session. + ApprovedForSession, + + /// User has denied this command and the agent should not execute it, but + /// it should continue the session and try something else. + #[default] + Denied, + + /// User has denied this command and the agent should not do anything until + /// the user's next command. + Abort, +} + impl AgentServer for Codex { fn name(&self) -> &'static str { "Codex" @@ -106,23 +158,26 @@ impl AgentServer for Codex { let (notification_tx, mut notification_rx) = mpsc::unbounded(); - codex_mcp_client + let client = codex_mcp_client .client() - .context("Failed to subscribe to server")? - .on_notification("codex/event", { - move |event, cx| { - let mut notification_tx = notification_tx.clone(); - cx.background_spawn(async move { - log::trace!("Notification: {:?}", event); - if let Some(event) = - serde_json::from_value::(event).log_err() - { - notification_tx.send(event.msg).await.log_err(); - } - }) - .detach(); - } - }); + .context("Failed to subscribe to server")?; + client.on_request::({ + move |elicitation: CodexApprovalRequest, cx| { + cx.spawn(async move |cx| anyhow::bail!("oops")) + } + }); + client.on_notification("codex/event", { + move |event, cx| { + let mut notification_tx = notification_tx.clone(); + cx.background_spawn(async move { + log::trace!("Notification: {:?}", event); + if let Some(event) = serde_json::from_value::(event).log_err() { + notification_tx.send(event.msg).await.log_err(); + } + }) + .detach(); + } + }); cx.new(|cx| { let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()); diff --git a/crates/context_server/src/client.rs b/crates/context_server/src/client.rs index 960f9717e58f2fc1502d5030da0908cdd8b76e73..d87de882137be4dc1ec14d1c33a946a84ea688f3 100644 --- a/crates/context_server/src/client.rs +++ b/crates/context_server/src/client.rs @@ -36,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603; type ResponseHandler = Box)>; type NotificationHandler = Box; +type RequestHandler = Box; #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[serde(untagged)] @@ -50,6 +51,7 @@ pub(crate) struct Client { outbound_tx: channel::Sender, name: Arc, notification_handlers: Arc>>, + request_handlers: Arc>>, response_handlers: Arc>>>, #[allow(clippy::type_complexity)] #[allow(dead_code)] @@ -82,6 +84,15 @@ pub struct Request<'a, T> { pub params: T, } +#[derive(Serialize, Deserialize)] +pub struct AnyRequest<'a> { + pub jsonrpc: &'a str, + pub id: RequestId, + pub method: &'a str, + #[serde(skip_serializing_if = "is_null_value")] + pub params: Option<&'a RawValue>, +} + #[derive(Serialize, Deserialize)] struct AnyResponse<'a> { jsonrpc: &'a str, @@ -180,15 +191,23 @@ impl Client { Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default())); let response_handlers = Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default()))); + let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default())); let receive_input_task = cx.spawn({ let notification_handlers = notification_handlers.clone(); let response_handlers = response_handlers.clone(); + let request_handlers = request_handlers.clone(); let transport = transport.clone(); async move |cx| { - Self::handle_input(transport, notification_handlers, response_handlers, cx) - .log_err() - .await + Self::handle_input( + transport, + notification_handlers, + request_handlers, + response_handlers, + cx, + ) + .log_err() + .await } }); let receive_err_task = cx.spawn({ @@ -215,6 +234,7 @@ impl Client { server_id, notification_handlers, response_handlers, + request_handlers, name: server_name, next_id: Default::default(), outbound_tx, @@ -234,23 +254,39 @@ impl Client { async fn handle_input( transport: Arc, notification_handlers: Arc>>, + request_handlers: Arc>>, response_handlers: Arc>>>, cx: &mut AsyncApp, ) -> anyhow::Result<()> { let mut receiver = transport.receive(); while let Some(message) = receiver.next().await { + log::trace!("recv: {}", &message); if let Ok(response) = serde_json::from_str::(&message) { + dbg!("here!"); if let Some(handlers) = response_handlers.lock().as_mut() { if let Some(handler) = handlers.remove(&response.id) { handler(Ok(message.to_string())); } } + } else if let Some(request) = serde_json::from_str::(&message).log_err() { + dbg!("here!"); + let mut request_handlers = request_handlers.lock(); + if let Some(handler) = request_handlers.get_mut(request.method) { + handler( + request.id, + request.params.unwrap_or(RawValue::NULL), + cx.clone(), + ); + } } else if let Ok(notification) = serde_json::from_str::(&message) { + dbg!("here!"); let mut notification_handlers = notification_handlers.lock(); if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) { handler(notification.params.unwrap_or(Value::Null), cx.clone()); } + } else { + dbg!("WTF", &message); } } @@ -419,6 +455,79 @@ impl Client { .lock() .insert(method, Box::new(f)); } + + pub fn on_request(&self, mut f: F) + where + F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task>, + { + let outbound_tx = self.outbound_tx.clone(); + self.request_handlers.lock().insert( + R::METHOD, + Box::new(move |id, json, cx| { + let outbound_tx = outbound_tx.clone(); + match serde_json::from_str(json.get()) { + Ok(req) => { + let task = f(req, cx.clone()); + cx.foreground_executor() + .spawn(async move { + match task.await { + Ok(res) => { + outbound_tx + .send( + serde_json::to_string(&Response { + jsonrpc: JSON_RPC_VERSION, + id, + value: CspResult::Ok(Some(res)), + }) + .unwrap(), + ) + .await + .ok(); + } + Err(e) => { + outbound_tx + .send( + serde_json::to_string(&Response { + jsonrpc: JSON_RPC_VERSION, + id, + value: CspResult::<()>::Error(Some(Error { + code: -1, // todo!() + message: format!("{e}"), + })), + }) + .unwrap(), + ) + .await + .ok(); + } + } + }) + .detach(); + } + Err(e) => { + cx.foreground_executor() + .spawn(async move { + outbound_tx + .send( + serde_json::to_string(&Response { + jsonrpc: JSON_RPC_VERSION, + id, + value: CspResult::<()>::Error(Some(Error { + code: -1, // todo!() + message: format!("{e}"), + })), + }) + .unwrap(), + ) + .await + .ok(); + }) + .detach(); + } + } + }), + ); + } } impl fmt::Display for ContextServerId { diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 7263f502fa44b05b6bba72eda58c7ad84b52ebf7..287b1389820a474ba066c3cccc5d03dc9a45b8d0 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -7,7 +7,9 @@ use anyhow::Result; use futures::channel::oneshot; -use gpui::AsyncApp; +use gpui::{AsyncApp, Task}; +use serde::Serialize; +use serde::de::DeserializeOwned; use serde_json::Value; use crate::client::Client; @@ -118,4 +120,11 @@ impl InitializedContextServerProtocol { { self.inner.on_notification(method, f); } + + pub fn on_request(&self, f: F) + where + F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task>, + { + self.inner.on_request::(f); + } } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index e565aba26b4caae298a063b7cd2036f5a7ee648d..832f2be69101917be8cf68f09e1fb7dee8382d6c 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -160,6 +160,7 @@ zed_actions.workspace = true zeta.workspace = true zlog.workspace = true zlog_settings.workspace = true +erased-serde = "0.4.6" [target.'cfg(target_os = "windows")'.dependencies] windows.workspace = true