codex.rs

  1use collections::HashMap;
  2use context_server::types::CallToolParams;
  3use context_server::types::requests::CallTool;
  4use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  5use futures::channel::mpsc;
  6use project::Project;
  7use settings::SettingsStore;
  8use smol::stream::StreamExt;
  9use std::cell::RefCell;
 10use std::path::{Path, PathBuf};
 11use std::rc::Rc;
 12use std::sync::Arc;
 13
 14use agentic_coding_protocol::{
 15    self as acp, AnyAgentRequest, AnyAgentResult, Client as _, ProtocolVersion,
 16};
 17use anyhow::{Context, Result, anyhow};
 18use futures::future::LocalBoxFuture;
 19use futures::{AsyncWriteExt, FutureExt, SinkExt as _};
 20use gpui::{App, AppContext, Entity, Task};
 21use serde::{Deserialize, Serialize};
 22use util::ResultExt;
 23
 24use crate::mcp_server::{McpConfig, ZedMcpServer};
 25use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
 26use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
 27
/// [`AgentServer`] backed by the `codex` CLI, launched in MCP server mode
/// (`codex mcp`) and driven through its `codex` tool.
#[derive(Clone)]
pub struct Codex;
 30
 31impl AgentServer for Codex {
 32    fn name(&self) -> &'static str {
 33        "Codex"
 34    }
 35
 36    fn empty_state_headline(&self) -> &'static str {
 37        self.name()
 38    }
 39
 40    fn empty_state_message(&self) -> &'static str {
 41        ""
 42    }
 43
 44    fn logo(&self) -> ui::IconName {
 45        ui::IconName::AiOpenAi
 46    }
 47
 48    fn supports_always_allow(&self) -> bool {
 49        false
 50    }
 51
 52    fn new_thread(
 53        &self,
 54        root_dir: &Path,
 55        project: &Entity<Project>,
 56        cx: &mut App,
 57    ) -> Task<Result<Entity<AcpThread>>> {
 58        let project = project.clone();
 59        let root_dir = root_dir.to_path_buf();
 60        let title = self.name().into();
 61        cx.spawn(async move |cx| {
 62            let (mut delegate_tx, delegate_rx) = watch::channel(None);
 63            let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
 64
 65            let zed_mcp_server = ZedMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
 66
 67            let mut mcp_servers = HashMap::default();
 68            mcp_servers.insert(
 69                crate::mcp_server::SERVER_NAME.to_string(),
 70                zed_mcp_server.server_config()?,
 71            );
 72            let mcp_config = McpConfig { mcp_servers };
 73
 74            // todo! pass zed mcp server to codex tool
 75            let mcp_config_file = tempfile::NamedTempFile::new()?;
 76            let (mcp_config_file, _mcp_config_path) = mcp_config_file.into_parts();
 77
 78            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 79            mcp_config_file
 80                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 81                .await?;
 82            mcp_config_file.flush().await?;
 83
 84            let settings = cx.read_global(|settings: &SettingsStore, _| {
 85                settings.get::<AllAgentServersSettings>(None).codex.clone()
 86            })?;
 87
 88            let Some(command) =
 89                AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
 90            else {
 91                anyhow::bail!("Failed to find codex binary");
 92            };
 93
 94            let codex_mcp_client: Arc<ContextServer> = ContextServer::stdio(
 95                ContextServerId("codex-mcp-server".into()),
 96                ContextServerCommand {
 97                    // todo! should we change ContextServerCommand to take a PathBuf?
 98                    path: command.path.to_string_lossy().to_string(),
 99                    args: command.args,
100                    env: command.env,
101                },
102            )
103            .into();
104
105            ContextServer::start(codex_mcp_client.clone(), cx).await?;
106            // todo! stop
107
108            let (notification_tx, mut notification_rx) = mpsc::unbounded();
109
110            codex_mcp_client
111                .client()
112                .context("Failed to subscribe to server")?
113                .on_notification("codex/event", {
114                    move |event, cx| {
115                        let mut notification_tx = notification_tx.clone();
116                        cx.background_spawn(async move {
117                            log::trace!("Notification: {:?}", event);
118                            if let Some(event) =
119                                serde_json::from_value::<CodexEvent>(event).log_err()
120                            {
121                                notification_tx.send(event.msg).await.log_err();
122                            }
123                        })
124                        .detach();
125                    }
126                });
127
128            cx.new(|cx| {
129                // todo! handle notifications
130                let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
131                delegate_tx.send(Some(delegate.clone())).log_err();
132
133                let handler_task = cx.spawn({
134                    let delegate = delegate.clone();
135                    async move |_, _cx| {
136                        while let Some(notification) = notification_rx.next().await {
137                            CodexAgentConnection::handle_acp_notification(&delegate, notification)
138                                .await
139                                .log_err();
140                        }
141                    }
142                });
143
144                let connection = CodexAgentConnection {
145                    root_dir,
146                    codex_mcp: codex_mcp_client,
147                    _handler_task: handler_task,
148                    _zed_mcp: zed_mcp_server,
149                };
150
151                acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
152            })
153        })
154    }
155}
156
157impl AgentConnection for CodexAgentConnection {
158    /// Send a request to the agent and wait for a response.
159    fn request_any(
160        &self,
161        params: AnyAgentRequest,
162    ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
163        let client = self.codex_mcp.client();
164        let root_dir = self.root_dir.clone();
165        async move {
166            let client = client.context("Codex MCP server is not initialized")?;
167
168            match params {
169                // todo: consider sending an empty request so we get the init response?
170                AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
171                    acp::InitializeResponse {
172                        is_authenticated: true,
173                        protocol_version: ProtocolVersion::latest(),
174                    },
175                )),
176                AnyAgentRequest::AuthenticateParams(_) => {
177                    Err(anyhow!("Authentication not supported"))
178                }
179                AnyAgentRequest::SendUserMessageParams(message) => {
180                    client
181                        .request::<CallTool>(CallToolParams {
182                            name: "codex".into(),
183                            arguments: Some(serde_json::to_value(CodexToolCallParam {
184                                prompt: message
185                                    .chunks
186                                    .into_iter()
187                                    .filter_map(|chunk| match chunk {
188                                        acp::UserMessageChunk::Text { text } => Some(text),
189                                        acp::UserMessageChunk::Path { .. } => {
190                                            // todo!
191                                            None
192                                        }
193                                    })
194                                    .collect(),
195                                cwd: root_dir,
196                            })?),
197                            meta: None,
198                        })
199                        .await?;
200
201                    Ok(AnyAgentResult::SendUserMessageResponse(
202                        acp::SendUserMessageResponse,
203                    ))
204                }
205                AnyAgentRequest::CancelSendMessageParams(_) => Ok(
206                    AnyAgentResult::CancelSendMessageResponse(acp::CancelSendMessageResponse),
207                ),
208            }
209        }
210        .boxed_local()
211    }
212}
213
/// [`AgentConnection`] that translates ACP requests into calls on a Codex MCP server.
///
/// NOTE: fields are dropped in declaration order; keep that in mind before reordering.
struct CodexAgentConnection {
    /// Client for the `codex mcp` process; `request_any` issues `codex` tool calls through it.
    codex_mcp: Arc<context_server::ContextServer>,
    /// Directory sent as `cwd` with every prompt.
    root_dir: PathBuf,
    /// Task draining Codex events into the thread; never read, only kept alive here
    /// (presumably dropping a gpui `Task` cancels it — confirm before changing ownership).
    _handler_task: Task<()>,
    /// Zed's own MCP server; never read, held so it lives as long as the connection.
    _zed_mcp: ZedMcpServer,
}
220
221impl CodexAgentConnection {
222    async fn handle_acp_notification(
223        delegate: &AcpClientDelegate,
224        event: AcpNotification,
225    ) -> Result<()> {
226        match event {
227            AcpNotification::AgentMessage(message) => {
228                delegate
229                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
230                        chunk: acp::AssistantMessageChunk::Text {
231                            text: message.message,
232                        },
233                    })
234                    .await?;
235            }
236            AcpNotification::Other => {}
237        }
238
239        Ok(())
240    }
241}
242
243/// todo! use types from h2a crate when we have one
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
246#[serde(rename_all = "kebab-case")]
247pub(crate) struct CodexToolCallParam {
248    pub prompt: String,
249    pub cwd: PathBuf,
250}
251
252#[derive(Debug, Clone, Serialize, Deserialize)]
253struct CodexEvent {
254    pub msg: AcpNotification,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
258#[serde(tag = "type", rename_all = "snake_case")]
259pub enum AcpNotification {
260    AgentMessage(AgentMessageEvent),
261    #[serde(other)]
262    Other,
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct AgentMessageEvent {
267    pub message: String,
268}