Stream text

Agus Zubiaga created

Change summary

crates/agent_servers/src/codex.rs     | 91 ++++++++++++++++++++++++++--
crates/context_server/src/client.rs   |  1 
crates/context_server/src/protocol.rs |  9 ++
3 files changed, 93 insertions(+), 8 deletions(-)

Detailed changes

crates/agent_servers/src/codex.rs 🔗

@@ -2,17 +2,21 @@ use collections::HashMap;
 use context_server::types::CallToolParams;
 use context_server::types::requests::CallTool;
 use context_server::{ContextServer, ContextServerCommand, ContextServerId};
+use futures::channel::mpsc;
 use project::Project;
 use settings::SettingsStore;
+use smol::stream::StreamExt;
 use std::cell::RefCell;
 use std::path::{Path, PathBuf};
 use std::rc::Rc;
 use std::sync::Arc;
 
-use agentic_coding_protocol::{self as acp, AnyAgentRequest, AnyAgentResult, ProtocolVersion};
+use agentic_coding_protocol::{
+    self as acp, AnyAgentRequest, AnyAgentResult, Client as _, ProtocolVersion,
+};
 use anyhow::{Context, Result, anyhow};
 use futures::future::LocalBoxFuture;
-use futures::{AsyncWriteExt, FutureExt};
+use futures::{AsyncWriteExt, FutureExt, SinkExt as _};
 use gpui::{App, AppContext, Entity, Task};
 use serde::{Deserialize, Serialize};
 use util::ResultExt;
@@ -101,15 +105,47 @@ impl AgentServer for Codex {
             ContextServer::start(codex_mcp_client.clone(), cx).await?;
             // todo! stop
 
+            let (notification_tx, mut notification_rx) = mpsc::unbounded();
+
+            codex_mcp_client
+                .client()
+                .context("Failed to subscribe to server")?
+                .on_notification("codex/event", {
+                    move |event, cx| {
+                        let mut notification_tx = notification_tx.clone();
+                        cx.background_spawn(async move {
+                            log::trace!("Notification: {:?}", event);
+                            if let Some(event) =
+                                serde_json::from_value::<CodexEvent>(event).log_err()
+                            {
+                                notification_tx.send(event.msg).await.log_err();
+                            }
+                        })
+                        .detach();
+                    }
+                });
+
             cx.new(|cx| {
                 // todo! handle notifications
                 let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
                 delegate_tx.send(Some(delegate.clone())).log_err();
 
+                let handler_task = cx.spawn({
+                    let delegate = delegate.clone();
+                    async move |_, _cx| {
+                        while let Some(notification) = notification_rx.next().await {
+                            CodexAgentConnection::handle_acp_notification(&delegate, notification)
+                                .await
+                                .log_err();
+                        }
+                    }
+                });
+
                 let connection = CodexAgentConnection {
                     root_dir,
-                    codex_mcp_client,
-                    _zed_mcp_server: zed_mcp_server,
+                    codex_mcp: codex_mcp_client,
+                    _handler_task: handler_task,
+                    _zed_mcp: zed_mcp_server,
                 };
 
                 acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
@@ -124,7 +160,7 @@ impl AgentConnection for CodexAgentConnection {
         &self,
         params: AnyAgentRequest,
     ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
-        let client = self.codex_mcp_client.client();
+        let client = self.codex_mcp.client();
         let root_dir = self.root_dir.clone();
         async move {
             let client = client.context("Codex MCP server is not initialized")?;
@@ -176,9 +212,32 @@ impl AgentConnection for CodexAgentConnection {
 }
 
 struct CodexAgentConnection {
-    codex_mcp_client: Arc<context_server::ContextServer>,
+    codex_mcp: Arc<context_server::ContextServer>,
     root_dir: PathBuf,
-    _zed_mcp_server: ZedMcpServer,
+    _handler_task: Task<()>,
+    _zed_mcp: ZedMcpServer,
+}
+
+impl CodexAgentConnection {
+    async fn handle_acp_notification(
+        delegate: &AcpClientDelegate,
+        event: AcpNotification,
+    ) -> Result<()> {
+        match event {
+            AcpNotification::AgentMessage(message) => {
+                delegate
+                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
+                        chunk: acp::AssistantMessageChunk::Text {
+                            text: message.message,
+                        },
+                    })
+                    .await?;
+            }
+            AcpNotification::Other => {}
+        }
+
+        Ok(())
+    }
 }
 
 /// todo! use types from h2a crate when we have one
@@ -189,3 +248,21 @@ pub(crate) struct CodexToolCallParam {
     pub prompt: String,
     pub cwd: PathBuf,
 }
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct CodexEvent {
+    pub msg: AcpNotification,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum AcpNotification {
+    AgentMessage(AgentMessageEvent),
+    #[serde(other)]
+    Other,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AgentMessageEvent {
+    pub message: String,
+}

crates/context_server/src/client.rs 🔗

@@ -243,7 +243,6 @@ impl Client {
                     }
                 }
             } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
-                dbg!(&notification);
                 let mut notification_handlers = notification_handlers.lock();
                 if let Some(handler) = notification_handlers.get_mut(notification.method.as_str()) {
                     handler(notification.params.unwrap_or(Value::Null), cx.clone());

crates/context_server/src/protocol.rs 🔗

@@ -6,6 +6,8 @@
 //! of messages.
 
 use anyhow::Result;
+use gpui::AsyncApp;
+use serde_json::Value;
 
 use crate::client::Client;
 use crate::types::{self, Notification, Request};
@@ -98,4 +100,11 @@ impl InitializedContextServerProtocol {
     pub fn notify<T: Notification>(&self, params: T::Params) -> Result<()> {
         self.inner.notify(T::METHOD, params)
     }
+
+    pub fn on_notification<F>(&self, method: &'static str, f: F)
+    where
+        F: 'static + Send + FnMut(Value, AsyncApp),
+    {
+        self.inner.on_notification(method, f);
+    }
 }