Cargo.lock 🔗
@@ -20203,6 +20203,7 @@ dependencies = [
"diagnostics",
"editor",
"env_logger 0.11.8",
+ "erased-serde",
"extension",
"extension_host",
"extensions_ui",
Conrad Irwin created
t
Cargo.lock | 1
crates/agent_servers/src/codex.rs | 87 +++++++++++++++++----
crates/context_server/src/client.rs | 115 ++++++++++++++++++++++++++++
crates/context_server/src/protocol.rs | 11 ++
crates/zed/Cargo.toml | 1
5 files changed, 195 insertions(+), 20 deletions(-)
@@ -20203,6 +20203,7 @@ dependencies = [
"diagnostics",
"editor",
"env_logger 0.11.8",
+ "erased-serde",
"extension",
"extension_host",
"extensions_ui",
@@ -5,6 +5,7 @@ use context_server::{ContextServer, ContextServerCommand, ContextServerId};
use futures::channel::{mpsc, oneshot};
use itertools::Itertools;
use project::Project;
+use serde::de::DeserializeOwned;
use settings::SettingsStore;
use smol::stream::StreamExt;
use std::cell::RefCell;
@@ -29,6 +30,57 @@ use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
#[derive(Clone)]
pub struct Codex;
+pub struct CodexApproval;
+impl context_server::types::Request for CodexApproval {
+ type Params = CodexApprovalRequest;
+ type Response = CodexApprovalResponse;
+ const METHOD: &'static str = "elicitation/create";
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CodexApprovalRequest {
+ // These fields are required so that `params`
+ // conforms to ElicitRequestParams.
+ pub message: String,
+ // #[serde(rename = "requestedSchema")]
+ // pub requested_schema: ElicitRequestParamsRequestedSchema,
+
+ // // These are additional fields the client can use to
+ // // correlate the request with the codex tool call.
+ // pub codex_elicitation: String,
+ // pub codex_mcp_tool_call_id: String,
+ // pub codex_event_id: String,
+ // pub codex_command: Vec<String>,
+ // pub codex_cwd: PathBuf,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CodexApprovalResponse {
+ pub decision: ReviewDecision,
+}
+
+/// User's decision in response to an ExecApprovalRequest.
+#[derive(Debug, Default, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum ReviewDecision {
+ /// User has approved this command and the agent should execute it.
+ Approved,
+
+ /// User has approved this command and wants to automatically approve any
+ /// future identical instances (`command` and `cwd` match exactly) for the
+ /// remainder of the session.
+ ApprovedForSession,
+
+ /// User has denied this command and the agent should not execute it, but
+ /// it should continue the session and try something else.
+ #[default]
+ Denied,
+
+ /// User has denied this command and the agent should not do anything until
+ /// the user's next command.
+ Abort,
+}
+
impl AgentServer for Codex {
fn name(&self) -> &'static str {
"Codex"
@@ -106,23 +158,26 @@ impl AgentServer for Codex {
let (notification_tx, mut notification_rx) = mpsc::unbounded();
- codex_mcp_client
+ let client = codex_mcp_client
.client()
- .context("Failed to subscribe to server")?
- .on_notification("codex/event", {
- move |event, cx| {
- let mut notification_tx = notification_tx.clone();
- cx.background_spawn(async move {
- log::trace!("Notification: {:?}", event);
- if let Some(event) =
- serde_json::from_value::<CodexEvent>(event).log_err()
- {
- notification_tx.send(event.msg).await.log_err();
- }
- })
- .detach();
- }
- });
+ .context("Failed to subscribe to server")?;
+ client.on_request::<CodexApproval, _>({
+ move |elicitation: CodexApprovalRequest, cx| {
+ cx.spawn(async move |cx| anyhow::bail!("oops"))
+ }
+ });
+ client.on_notification("codex/event", {
+ move |event, cx| {
+ let mut notification_tx = notification_tx.clone();
+ cx.background_spawn(async move {
+ log::trace!("Notification: {:?}", event);
+ if let Some(event) = serde_json::from_value::<CodexEvent>(event).log_err() {
+ notification_tx.send(event.msg).await.log_err();
+ }
+ })
+ .detach();
+ }
+ });
cx.new(|cx| {
let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
@@ -36,6 +36,7 @@ pub const INTERNAL_ERROR: i32 = -32603;
type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
+type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
@@ -50,6 +51,7 @@ pub(crate) struct Client {
outbound_tx: channel::Sender<String>,
name: Arc<str>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+ request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
#[allow(clippy::type_complexity)]
#[allow(dead_code)]
@@ -82,6 +84,15 @@ pub struct Request<'a, T> {
pub params: T,
}
+#[derive(Serialize, Deserialize)]
+pub struct AnyRequest<'a> {
+ pub jsonrpc: &'a str,
+ pub id: RequestId,
+ pub method: &'a str,
+ #[serde(skip_serializing_if = "is_null_value")]
+ pub params: Option<&'a RawValue>,
+}
+
#[derive(Serialize, Deserialize)]
struct AnyResponse<'a> {
jsonrpc: &'a str,
@@ -180,15 +191,23 @@ impl Client {
Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
let response_handlers =
Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
+ let request_handlers = Arc::new(Mutex::new(HashMap::<_, RequestHandler>::default()));
let receive_input_task = cx.spawn({
let notification_handlers = notification_handlers.clone();
let response_handlers = response_handlers.clone();
+ let request_handlers = request_handlers.clone();
let transport = transport.clone();
async move |cx| {
- Self::handle_input(transport, notification_handlers, response_handlers, cx)
- .log_err()
- .await
+ Self::handle_input(
+ transport,
+ notification_handlers,
+ request_handlers,
+ response_handlers,
+ cx,
+ )
+ .log_err()
+ .await
}
});
let receive_err_task = cx.spawn({
@@ -215,6 +234,7 @@ impl Client {
server_id,
notification_handlers,
response_handlers,
+ request_handlers,
name: server_name,
next_id: Default::default(),
outbound_tx,
@@ -234,23 +254,39 @@ impl Client {
async fn handle_input(
transport: Arc<dyn Transport>,
notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+ request_handlers: Arc<Mutex<HashMap<&'static str, RequestHandler>>>,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
cx: &mut AsyncApp,
) -> anyhow::Result<()> {
let mut receiver = transport.receive();
while let Some(message) = receiver.next().await {
+ log::trace!("recv: {}", &message);
if let Ok(response) = serde_json::from_str::<AnyResponse>(&message) {
+ dbg!("here!");
if let Some(handlers) = response_handlers.lock().as_mut() {
if let Some(handler) = handlers.remove(&response.id) {
handler(Ok(message.to_string()));
}
}
+ } else if let Some(request) = serde_json::from_str::<AnyRequest>(&message).log_err() {
+ dbg!("here!");
+ let mut request_handlers = request_handlers.lock();
+ if let Some(handler) = request_handlers.get_mut(request.method) {
+ handler(
+ request.id,
+ request.params.unwrap_or(RawValue::NULL),
+ cx.clone(),
+ );
+ }
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
+ dbg!("here!");
let mut notification_handlers = notification_handlers.lock();
if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
handler(notification.params.unwrap_or(Value::Null), cx.clone());
}
+ } else {
+ dbg!("WTF", &message);
}
}
@@ -419,6 +455,79 @@ impl Client {
.lock()
.insert(method, Box::new(f));
}
+
+ pub fn on_request<R: crate::types::Request, F>(&self, mut f: F)
+ where
+ F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task<Result<R::Response>>,
+ {
+ let outbound_tx = self.outbound_tx.clone();
+ self.request_handlers.lock().insert(
+ R::METHOD,
+ Box::new(move |id, json, cx| {
+ let outbound_tx = outbound_tx.clone();
+ match serde_json::from_str(json.get()) {
+ Ok(req) => {
+ let task = f(req, cx.clone());
+ cx.foreground_executor()
+ .spawn(async move {
+ match task.await {
+ Ok(res) => {
+ outbound_tx
+ .send(
+ serde_json::to_string(&Response {
+ jsonrpc: JSON_RPC_VERSION,
+ id,
+ value: CspResult::Ok(Some(res)),
+ })
+ .unwrap(),
+ )
+ .await
+ .ok();
+ }
+ Err(e) => {
+ outbound_tx
+ .send(
+ serde_json::to_string(&Response {
+ jsonrpc: JSON_RPC_VERSION,
+ id,
+ value: CspResult::<()>::Error(Some(Error {
+ code: -1, // todo!()
+ message: format!("{e}"),
+ })),
+ })
+ .unwrap(),
+ )
+ .await
+ .ok();
+ }
+ }
+ })
+ .detach();
+ }
+ Err(e) => {
+ cx.foreground_executor()
+ .spawn(async move {
+ outbound_tx
+ .send(
+ serde_json::to_string(&Response {
+ jsonrpc: JSON_RPC_VERSION,
+ id,
+ value: CspResult::<()>::Error(Some(Error {
+ code: -1, // todo!()
+ message: format!("{e}"),
+ })),
+ })
+ .unwrap(),
+ )
+ .await
+ .ok();
+ })
+ .detach();
+ }
+ }
+ }),
+ );
+ }
}
impl fmt::Display for ContextServerId {
@@ -7,7 +7,9 @@
use anyhow::Result;
use futures::channel::oneshot;
-use gpui::AsyncApp;
+use gpui::{AsyncApp, Task};
+use serde::Serialize;
+use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::client::Client;
@@ -118,4 +120,11 @@ impl InitializedContextServerProtocol {
{
self.inner.on_notification(method, f);
}
+
+ pub fn on_request<R: crate::types::Request, F>(&self, f: F)
+ where
+ F: 'static + Send + FnMut(R::Params, AsyncApp) -> Task<Result<R::Response>>,
+ {
+ self.inner.on_request::<R, F>(f);
+ }
}
@@ -160,6 +160,7 @@ zed_actions.workspace = true
zeta.workspace = true
zlog.workspace = true
zlog_settings.workspace = true
+erased-serde = "0.4.6"
[target.'cfg(target_os = "windows")'.dependencies]
windows.workspace = true