codex.rs

  1use agent_client_protocol as acp;
  2use anyhow::anyhow;
  3use collections::HashMap;
  4use context_server::listener::McpServerTool;
  5use context_server::types::requests;
  6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  7use futures::channel::{mpsc, oneshot};
  8use project::Project;
  9use settings::SettingsStore;
 10use smol::stream::StreamExt as _;
 11use std::cell::RefCell;
 12use std::rc::Rc;
 13use std::{path::Path, sync::Arc};
 14use util::ResultExt;
 15
 16use anyhow::{Context, Result};
 17use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 18
 19use crate::mcp_server::ZedMcpServer;
 20use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
 21use acp_thread::{AcpThread, AgentConnection};
 22
/// Stateless marker type that implements [`AgentServer`] for Codex.
///
/// It carries no data; its `connect` implementation resolves the `codex`
/// command (invoked with the `mcp` argument) and talks to it over stdio.
#[derive(Clone)]
pub struct Codex;
 25
 26impl AgentServer for Codex {
 27    fn name(&self) -> &'static str {
 28        "Codex"
 29    }
 30
 31    fn empty_state_headline(&self) -> &'static str {
 32        "Welcome to Codex"
 33    }
 34
 35    fn empty_state_message(&self) -> &'static str {
 36        "What can I help with?"
 37    }
 38
 39    fn logo(&self) -> ui::IconName {
 40        ui::IconName::AiOpenAi
 41    }
 42
 43    fn connect(
 44        &self,
 45        _root_dir: &Path,
 46        project: &Entity<Project>,
 47        cx: &mut App,
 48    ) -> Task<Result<Rc<dyn AgentConnection>>> {
 49        let project = project.clone();
 50        cx.spawn(async move |cx| {
 51            let settings = cx.read_global(|settings: &SettingsStore, _| {
 52                settings.get::<AllAgentServersSettings>(None).codex.clone()
 53            })?;
 54
 55            let Some(command) =
 56                AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
 57            else {
 58                anyhow::bail!("Failed to find codex binary");
 59            };
 60
 61            let client: Arc<ContextServer> = ContextServer::stdio(
 62                ContextServerId("codex-mcp-server".into()),
 63                ContextServerCommand {
 64                    path: command.path,
 65                    args: command.args,
 66                    env: command.env,
 67                },
 68            )
 69            .into();
 70            ContextServer::start(client.clone(), cx).await?;
 71
 72            let (notification_tx, mut notification_rx) = mpsc::unbounded();
 73            client
 74                .client()
 75                .context("Failed to subscribe")?
 76                .on_notification(acp::AGENT_METHODS.session_update, {
 77                    move |notification, _cx| {
 78                        let notification_tx = notification_tx.clone();
 79                        log::trace!(
 80                            "ACP Notification: {}",
 81                            serde_json::to_string_pretty(&notification).unwrap()
 82                        );
 83
 84                        if let Some(notification) =
 85                            serde_json::from_value::<acp::SessionNotification>(notification)
 86                                .log_err()
 87                        {
 88                            notification_tx.unbounded_send(notification).ok();
 89                        }
 90                    }
 91                });
 92
 93            let sessions = Rc::new(RefCell::new(HashMap::default()));
 94
 95            let notification_handler_task = cx.spawn({
 96                let sessions = sessions.clone();
 97                async move |cx| {
 98                    while let Some(notification) = notification_rx.next().await {
 99                        CodexConnection::handle_session_notification(
100                            notification,
101                            sessions.clone(),
102                            cx,
103                        )
104                    }
105                }
106            });
107
108            let connection = CodexConnection {
109                client,
110                sessions,
111                _notification_handler_task: notification_handler_task,
112            };
113            Ok(Rc::new(connection) as _)
114        })
115    }
116}
117
118struct CodexConnection {
119    client: Arc<context_server::ContextServer>,
120    sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
121    _notification_handler_task: Task<()>,
122}
123
/// Per-session state tracked by [`CodexConnection`].
struct CodexSession {
    /// Weak handle to the thread that receives this session's updates.
    thread: WeakEntity<AcpThread>,
    /// Cancellation sender stored by `prompt` and consumed by `cancel`;
    /// `None` for a new session and again after `cancel` has taken it.
    cancel_tx: Option<oneshot::Sender<()>>,
    /// MCP server created for this session in `new_thread`.
    /// NOTE(review): never read; presumably held so it lives as long as the
    /// session entry — confirm.
    _mcp_server: ZedMcpServer,
}
129
130impl AgentConnection for CodexConnection {
131    fn name(&self) -> &'static str {
132        "Codex"
133    }
134
135    fn new_thread(
136        self: Rc<Self>,
137        project: Entity<Project>,
138        cwd: &Path,
139        cx: &mut AsyncApp,
140    ) -> Task<Result<Entity<AcpThread>>> {
141        let client = self.client.client();
142        let sessions = self.sessions.clone();
143        let cwd = cwd.to_path_buf();
144        cx.spawn(async move |cx| {
145            let client = client.context("MCP server is not initialized yet")?;
146            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
147
148            let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
149
150            let response = client
151                .request::<requests::CallTool>(context_server::types::CallToolParams {
152                    name: acp::AGENT_METHODS.new_session.into(),
153                    arguments: Some(serde_json::to_value(acp::NewSessionArguments {
154                        mcp_servers: vec![mcp_server.server_config()?],
155                        client_tools: acp::ClientTools {
156                            request_permission: Some(acp::McpToolId {
157                                mcp_server: mcp_server::SERVER_NAME.into(),
158                                tool_name: mcp_server::RequestPermissionTool::NAME.into(),
159                            }),
160                            read_text_file: Some(acp::McpToolId {
161                                mcp_server: mcp_server::SERVER_NAME.into(),
162                                tool_name: mcp_server::ReadTextFileTool::NAME.into(),
163                            }),
164                            write_text_file: Some(acp::McpToolId {
165                                mcp_server: mcp_server::SERVER_NAME.into(),
166                                tool_name: mcp_server::WriteTextFileTool::NAME.into(),
167                            }),
168                        },
169                        cwd,
170                    })?),
171                    meta: None,
172                })
173                .await?;
174
175            if response.is_error.unwrap_or_default() {
176                return Err(anyhow!(response.text_contents()));
177            }
178
179            let result = serde_json::from_value::<acp::NewSessionOutput>(
180                response.structured_content.context("Empty response")?,
181            )?;
182
183            let thread =
184                cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
185
186            thread_tx.send(thread.downgrade())?;
187
188            let session = CodexSession {
189                thread: thread.downgrade(),
190                cancel_tx: None,
191                _mcp_server: mcp_server,
192            };
193            sessions.borrow_mut().insert(result.session_id, session);
194
195            Ok(thread)
196        })
197    }
198
199    fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
200        Task::ready(Err(anyhow!("Authentication not supported")))
201    }
202
203    fn prompt(
204        &self,
205        params: agent_client_protocol::PromptArguments,
206        cx: &mut App,
207    ) -> Task<Result<()>> {
208        let client = self.client.client();
209        let sessions = self.sessions.clone();
210
211        cx.foreground_executor().spawn(async move {
212            let client = client.context("MCP server is not initialized yet")?;
213
214            let (new_cancel_tx, cancel_rx) = oneshot::channel();
215            {
216                let mut sessions = sessions.borrow_mut();
217                let session = sessions
218                    .get_mut(&params.session_id)
219                    .context("Session not found")?;
220                session.cancel_tx.replace(new_cancel_tx);
221            }
222
223            let result = client
224                .request_with::<requests::CallTool>(
225                    context_server::types::CallToolParams {
226                        name: acp::AGENT_METHODS.prompt.into(),
227                        arguments: Some(serde_json::to_value(params)?),
228                        meta: None,
229                    },
230                    Some(cancel_rx),
231                    None,
232                )
233                .await;
234
235            if let Err(err) = &result
236                && err.is::<context_server::client::RequestCanceled>()
237            {
238                return Ok(());
239            }
240
241            let response = result?;
242
243            if response.is_error.unwrap_or_default() {
244                return Err(anyhow!(response.text_contents()));
245            }
246
247            Ok(())
248        })
249    }
250
251    fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
252        let mut sessions = self.sessions.borrow_mut();
253
254        if let Some(cancel_tx) = sessions
255            .get_mut(session_id)
256            .and_then(|session| session.cancel_tx.take())
257        {
258            cancel_tx.send(()).ok();
259        }
260    }
261}
262
263impl CodexConnection {
264    pub fn handle_session_notification(
265        notification: acp::SessionNotification,
266        threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
267        cx: &mut AsyncApp,
268    ) {
269        let threads = threads.borrow();
270        let Some(thread) = threads
271            .get(&notification.session_id)
272            .and_then(|session| session.thread.upgrade())
273        else {
274            log::error!(
275                "Thread not found for session ID: {}",
276                notification.session_id
277            );
278            return;
279        };
280
281        thread
282            .update(cx, |thread, cx| {
283                thread.handle_session_update(notification.update, cx)
284            })
285            .log_err();
286    }
287}
288
289impl Drop for CodexConnection {
290    fn drop(&mut self) {
291        self.client.stop().log_err();
292    }
293}
294
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::AgentServerCommand;
    use std::path::Path;

    crate::common_e2e_tests!(Codex, allow_option_id = "approve");

    /// Builds the command for a locally built debug `codex` binary.
    ///
    /// The path is taken relative to this crate's manifest directory and the
    /// binary is invoked with the single argument `mcp` and no extra
    /// environment.
    pub fn local_command() -> AgentServerCommand {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));

        AgentServerCommand {
            path: manifest_dir.join("../../../codex/codex-rs/target/debug/codex"),
            args: vec!["mcp".into()],
            env: None,
        }
    }
}