WIP elciication

Conrad Irwin created

t

Change summary

Cargo.lock                            |   1 
crates/agent_servers/src/codex.rs     |  87 +++++++++++++++++----
crates/context_server/src/client.rs   | 115 ++++++++++++++++++++++++++++
crates/context_server/src/protocol.rs |  11 ++
crates/zed/Cargo.toml                 |   1 
5 files changed, 195 insertions(+), 20 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -20203,6 +20203,7 @@ dependencies = [
  "diagnostics",
  "editor",
  "env_logger 0.11.8",
+ "erased-serde",
  "extension",
  "extension_host",
  "extensions_ui",

crates/agent_servers/src/codex.rs 🔗

@@ -5,6 +5,7 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use futures::channel::{mpsc, oneshot};
 use itertools::Itertools;
 use project::Project;
+use serde::de::DeserializeOwned;
 use settings::SettingsStore;
 use smol::stream::StreamExt;
 use std::cell::RefCell;
@@ -29,6 +30,57 @@ use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
 #[derive(Clone)]
 pub struct Codex;
 
+pub struct CodexApproval;
+impl context_server::types::Request for CodexApproval {
+    type Params = CodexApprovalRequest;
+    type Response = CodexApprovalResponse;
+    const METHOD: &'static str = "elicitation/create";
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CodexApprovalRequest {
+    // These fields are required so that `params`
+    // conforms to ElicitRequestParams.
+    pub message: String,
+    // #[serde(rename = "requestedSchema")]
+    // pub requested_schema: ElicitRequestParamsRequestedSchema,
+
+    // // These are additional fields the client can use to
+    // // correlate the request with the codex tool call.
+    // pub codex_elicitation: String,
+    // pub codex_mcp_tool_call_id: String,
+    // pub codex_event_id: String,
+    // pub codex_command: Vec<String>,
+    // pub codex_cwd: PathBuf,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CodexApprovalResponse {
+    pub decision: ReviewDecision,
+}
+
+/// User's decision in response to an ExecApprovalRequest.
+#[derive(Debug, Default, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum ReviewDecision {
+    /// User has approved this command and the agent should execute it.
+    Approved,
+
+    /// User has approved this command and wants to automatically approve any
+    /// future identical instances (`command` and `cwd` match exactly) for the
+    /// remainder of the session.
+    ApprovedForSession,
+
+    /// User has denied this command and the agent should not execute it, but
+    /// it should continue the session and try something else.
+    #[default]
+    Denied,
+
+    /// User has denied this command and the agent should not do anything until
+    /// the user's next command.
+    Abort,
+}
+
 impl AgentServer for Codex {
     fn name(&self) -> &'static str {
         "Codex"
@@ -106,23 +158,26 @@ impl AgentServer for Codex {
 
             let (notification_tx, mut notification_rx) = mpsc::unbounded();
 
-            codex_mcp_client
+            let client = codex_mcp_client
                 .client()
-                .context("Failed to subscribe to server")?
-                .on_notification("codex/event", {
-                    move |event, cx| {
-                        let mut notification_tx = notification_tx.clone();
-                        cx.background_spawn(async move {
-                            log::trace!("Notification: {:?}", event);
-                            if let Some(event) =
-                                serde_json::from_value::<CodexEvent>(event).log_err()
-                            {
-                                notification_tx.send(event.msg).await.log_err();
-                            }
-                        })
-                        .detach();
-                    }
-                });
+                .context("Failed to subscribe to server")?;
+            client.on_request::<CodexApproval, _>({
+                move |elicitation: CodexApprovalRequest, cx| {
+                    cx.spawn(async move |cx| anyhow::bail!("oops"))
+                }
+            });
+            client.on_notification("codex/event", {
+                move |event, cx| {
+                    let mut notification_tx = notification_tx.clone();
+                    cx.background_spawn(async move {
+                        log::trace!("Notification: {:?}", event);
+                        if let Some(event) = serde_json::from_value::<CodexEvent>(event).log_err() {
+                            notification_tx.send(event.msg).await.log_err();
+                        }
+                    })
+                    .detach();
+                }
+            });
 
             cx.new(|cx| {
                 let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());

crates/context_server/src/client.rs 🔗

@@ -36,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603;
 
 type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
 type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
+type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
 
 #[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
 #[serde(untagged)]
@@ -50,6 +51,7 @@ pub(crate) struct Client {
     outbound_tx: channel::Sender<String>,
     name: Arc<str>,
     notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+    request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
     response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
     #[allow(clippy::type_complexity)]
     #[allow(dead_code)]
@@ -82,6 +84,15 @@ pub struct Request<'a, T> {
     pub params: T,
 }
 
+#[derive(Serialize, Deserialize)]
+pub struct AnyRequest<'a> {
+    pub jsonrpc: &'a str,
+    pub id: RequestId,
+    pub method: &'a str,
+    #[serde(skip_serializing_if = "is_null_value")]
+    pub params: Option<&'a RawValue>,
+}
+
 #[derive(Serialize, Deserialize)]
 struct AnyResponse<'a> {
     jsonrpc: &'a str,
@@ -180,15 +191,23 @@ impl Client {
             Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
         let response_handlers =
             Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
+        let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
 
         let receive_input_task = cx.spawn({
             let notification_handlers = notification_handlers.clone();
             let response_handlers = response_handlers.clone();
+            let request_handlers = request_handlers.clone();
             let transport = transport.clone();
             async move |cx| {
-                Self::handle_input(transport, notification_handlers, response_handlers, cx)
-                    .log_err()
-                    .await
+                Self::handle_input(
+                    transport,
+                    notification_handlers,
+                    request_handlers,
+                    response_handlers,
+                    cx,
+                )
+                .log_err()
+                .await
             }
         });
         let receive_err_task = cx.spawn({
@@ -215,6 +234,7 @@ impl Client {
             server_id,
             notification_handlers,
             response_handlers,
+            request_handlers,
             name: server_name,
             next_id: Default::default(),
             outbound_tx,
@@ -234,23 +254,39 @@ impl Client {
     async fn handle_input(
         transport: Arc<dyn Transport>,
         notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+        request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
         response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
         cx: &mut AsyncApp,
     ) -> anyhow::Result<()> {
         let mut receiver = transport.receive();
 
         while let Some(message) = receiver.next().await {
+            log::trace!("recv: {}", &message);
             if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
+                dbg!("here!");
                 if let Some(handlers) = response_handlers.lock().as_mut() {
                     if let Some(handler) = handlers.remove(&response.id) {
                         handler(Ok(message.to_string()));
                     }
                 }
+            } else if let Some(request) = serde_json::from_str::<AnyRequest>(&message).log_err() {
+                dbg!("here!");
+                let mut request_handlers = request_handlers.lock();
+                if let Some(handler) = request_handlers.get_mut(request.method) {
+                    handler(
+                        request.id,
+                        request.params.unwrap_or(RawValue::NULL),
+                        cx.clone(),
+                    );
+                }
             } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
+                dbg!("here!");
                 let mut notification_handlers = notification_handlers.lock();
                 if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
                     handler(notification.params.unwrap_or(Value::Null), cx.clone());
                 }
+            } else {
+                dbg!("WTF", &message);
             }
         }
 
@@ -419,6 +455,79 @@ impl Client {
             .lock()
             .insert(method, Box::new(f));
     }
+
+    pub fn on_request<R: crate::types::Request, F>(&self, mut f: F)
+    where
+        F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task<Result<R::Response>>,
+    {
+        let outbound_tx = self.outbound_tx.clone();
+        self.request_handlers.lock().insert(
+            R::METHOD,
+            Box::new(move |id, json, cx| {
+                let outbound_tx = outbound_tx.clone();
+                match serde_json::from_str(json.get()) {
+                    Ok(req) => {
+                        let task = f(req, cx.clone());
+                        cx.foreground_executor()
+                            .spawn(async move {
+                                match task.await {
+                                    Ok(res) => {
+                                        outbound_tx
+                                            .send(
+                                                serde_json::to_string(&Response {
+                                                    jsonrpc: JSON_RPC_VERSION,
+                                                    id,
+                                                    value: CspResult::Ok(Some(res)),
+                                                })
+                                                .unwrap(),
+                                            )
+                                            .await
+                                            .ok();
+                                    }
+                                    Err(e) => {
+                                        outbound_tx
+                                            .send(
+                                                serde_json::to_string(&Response {
+                                                    jsonrpc: JSON_RPC_VERSION,
+                                                    id,
+                                                    value: CspResult::<()>::Error(Some(Error {
+                                                        code: -1, // todo!()
+                                                        message: format!("{e}"),
+                                                    })),
+                                                })
+                                                .unwrap(),
+                                            )
+                                            .await
+                                            .ok();
+                                    }
+                                }
+                            })
+                            .detach();
+                    }
+                    Err(e) => {
+                        cx.foreground_executor()
+                            .spawn(async move {
+                                outbound_tx
+                                    .send(
+                                        serde_json::to_string(&Response {
+                                            jsonrpc: JSON_RPC_VERSION,
+                                            id,
+                                            value: CspResult::<()>::Error(Some(Error {
+                                                code: -1, // todo!()
+                                                message: format!("{e}"),
+                                            })),
+                                        })
+                                        .unwrap(),
+                                    )
+                                    .await
+                                    .ok();
+                            })
+                            .detach();
+                    }
+                }
+            }),
+        );
+    }
 }
 
 impl fmt::Display for ContextServerId {

crates/context_server/src/protocol.rs 🔗

@@ -7,7 +7,9 @@
 
 use anyhow::Result;
 use futures::channel::oneshot;
-use gpui::AsyncApp;
+use gpui::{AsyncApp, Task};
+use serde::Serialize;
+use serde::de::DeserializeOwned;
 use serde_json::Value;
 
 use crate::client::Client;
@@ -118,4 +120,11 @@ impl InitializedContextServerProtocol {
     {
         self.inner.on_notification(method, f);
     }
+
+    pub fn on_request<R: crate::types::Request, F>(&self, f: F)
+    where
+        F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task<Result<R::Response>>,
+    {
+        self.inner.on_request::<R, F>(f);
+    }
 }

crates/zed/Cargo.toml 🔗

@@ -160,6 +160,7 @@ zed_actions.workspace = true
 zeta.workspace = true
 zlog.workspace = true
 zlog_settings.workspace = true
+erased-serde = "0.4.6"
 
 [target.'cfg(target_os = "windows")'.dependencies]
 windows.workspace = true