codex.rs

  1use agent_client_protocol as acp;
  2use anyhow::anyhow;
  3use collections::HashMap;
  4use context_server::listener::McpServerTool;
  5use context_server::types::requests;
  6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  7use futures::channel::{mpsc, oneshot};
  8use project::Project;
  9use settings::SettingsStore;
 10use smol::stream::StreamExt as _;
 11use std::cell::RefCell;
 12use std::rc::Rc;
 13use std::{path::Path, sync::Arc};
 14use util::ResultExt;
 15
 16use anyhow::{Context, Result};
 17use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 18
 19use crate::mcp_server::ZedMcpServer;
 20use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
 21use acp_thread::{AcpThread, AgentConnection};
 22
 23#[derive(Clone)]
 24pub struct Codex;
 25
 26impl AgentServer for Codex {
 27    fn name(&self) -> &'static str {
 28        "Codex"
 29    }
 30
 31    fn empty_state_headline(&self) -> &'static str {
 32        "Welcome to Codex"
 33    }
 34
 35    fn empty_state_message(&self) -> &'static str {
 36        "What can I help with?"
 37    }
 38
 39    fn logo(&self) -> ui::IconName {
 40        ui::IconName::AiOpenAi
 41    }
 42
 43    fn connect(
 44        &self,
 45        _root_dir: &Path,
 46        project: &Entity<Project>,
 47        cx: &mut App,
 48    ) -> Task<Result<Rc<dyn AgentConnection>>> {
 49        let project = project.clone();
 50        let working_directory = project.read(cx).active_project_directory(cx);
 51        cx.spawn(async move |cx| {
 52            let settings = cx.read_global(|settings: &SettingsStore, _| {
 53                settings.get::<AllAgentServersSettings>(None).codex.clone()
 54            })?;
 55
 56            let Some(command) =
 57                AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
 58            else {
 59                anyhow::bail!("Failed to find codex binary");
 60            };
 61
 62            let client: Arc<ContextServer> = ContextServer::stdio(
 63                ContextServerId("codex-mcp-server".into()),
 64                ContextServerCommand {
 65                    path: command.path,
 66                    args: command.args,
 67                    env: command.env,
 68                },
 69                working_directory,
 70            )
 71            .into();
 72            ContextServer::start(client.clone(), cx).await?;
 73
 74            let (notification_tx, mut notification_rx) = mpsc::unbounded();
 75            client
 76                .client()
 77                .context("Failed to subscribe")?
 78                .on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
 79                    move |notification, _cx| {
 80                        let notification_tx = notification_tx.clone();
 81                        log::trace!(
 82                            "ACP Notification: {}",
 83                            serde_json::to_string_pretty(&notification).unwrap()
 84                        );
 85
 86                        if let Some(notification) =
 87                            serde_json::from_value::<acp::SessionNotification>(notification)
 88                                .log_err()
 89                        {
 90                            notification_tx.unbounded_send(notification).ok();
 91                        }
 92                    }
 93                });
 94
 95            let sessions = Rc::new(RefCell::new(HashMap::default()));
 96
 97            let notification_handler_task = cx.spawn({
 98                let sessions = sessions.clone();
 99                async move |cx| {
100                    while let Some(notification) = notification_rx.next().await {
101                        CodexConnection::handle_session_notification(
102                            notification,
103                            sessions.clone(),
104                            cx,
105                        )
106                    }
107                }
108            });
109
110            let connection = CodexConnection {
111                client,
112                sessions,
113                _notification_handler_task: notification_handler_task,
114            };
115            Ok(Rc::new(connection) as _)
116        })
117    }
118}
119
120struct CodexConnection {
121    client: Arc<context_server::ContextServer>,
122    sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
123    _notification_handler_task: Task<()>,
124}
125
126struct CodexSession {
127    thread: WeakEntity<AcpThread>,
128    cancel_tx: Option<oneshot::Sender<()>>,
129    _mcp_server: ZedMcpServer,
130}
131
132impl AgentConnection for CodexConnection {
133    fn name(&self) -> &'static str {
134        "Codex"
135    }
136
137    fn new_thread(
138        self: Rc<Self>,
139        project: Entity<Project>,
140        cwd: &Path,
141        cx: &mut AsyncApp,
142    ) -> Task<Result<Entity<AcpThread>>> {
143        let client = self.client.client();
144        let sessions = self.sessions.clone();
145        let cwd = cwd.to_path_buf();
146        cx.spawn(async move |cx| {
147            let client = client.context("MCP server is not initialized yet")?;
148            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
149
150            let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
151
152            let response = client
153                .request::<requests::CallTool>(context_server::types::CallToolParams {
154                    name: acp::NEW_SESSION_TOOL_NAME.into(),
155                    arguments: Some(serde_json::to_value(acp::NewSessionArguments {
156                        mcp_servers: [(
157                            mcp_server::SERVER_NAME.to_string(),
158                            mcp_server.server_config()?,
159                        )]
160                        .into(),
161                        client_tools: acp::ClientTools {
162                            request_permission: Some(acp::McpToolId {
163                                mcp_server: mcp_server::SERVER_NAME.into(),
164                                tool_name: mcp_server::RequestPermissionTool::NAME.into(),
165                            }),
166                            read_text_file: Some(acp::McpToolId {
167                                mcp_server: mcp_server::SERVER_NAME.into(),
168                                tool_name: mcp_server::ReadTextFileTool::NAME.into(),
169                            }),
170                            write_text_file: Some(acp::McpToolId {
171                                mcp_server: mcp_server::SERVER_NAME.into(),
172                                tool_name: mcp_server::WriteTextFileTool::NAME.into(),
173                            }),
174                        },
175                        cwd,
176                    })?),
177                    meta: None,
178                })
179                .await?;
180
181            if response.is_error.unwrap_or_default() {
182                return Err(anyhow!(response.text_contents()));
183            }
184
185            let result = serde_json::from_value::<acp::NewSessionOutput>(
186                response.structured_content.context("Empty response")?,
187            )?;
188
189            let thread =
190                cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
191
192            thread_tx.send(thread.downgrade())?;
193
194            let session = CodexSession {
195                thread: thread.downgrade(),
196                cancel_tx: None,
197                _mcp_server: mcp_server,
198            };
199            sessions.borrow_mut().insert(result.session_id, session);
200
201            Ok(thread)
202        })
203    }
204
205    fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
206        Task::ready(Err(anyhow!("Authentication not supported")))
207    }
208
209    fn prompt(
210        &self,
211        params: agent_client_protocol::PromptArguments,
212        cx: &mut App,
213    ) -> Task<Result<()>> {
214        let client = self.client.client();
215        let sessions = self.sessions.clone();
216
217        cx.foreground_executor().spawn(async move {
218            let client = client.context("MCP server is not initialized yet")?;
219
220            let (new_cancel_tx, cancel_rx) = oneshot::channel();
221            {
222                let mut sessions = sessions.borrow_mut();
223                let session = sessions
224                    .get_mut(&params.session_id)
225                    .context("Session not found")?;
226                session.cancel_tx.replace(new_cancel_tx);
227            }
228
229            let result = client
230                .request_with::<requests::CallTool>(
231                    context_server::types::CallToolParams {
232                        name: acp::PROMPT_TOOL_NAME.into(),
233                        arguments: Some(serde_json::to_value(params)?),
234                        meta: None,
235                    },
236                    Some(cancel_rx),
237                    None,
238                )
239                .await;
240
241            if let Err(err) = &result
242                && err.is::<context_server::client::RequestCanceled>()
243            {
244                return Ok(());
245            }
246
247            let response = result?;
248
249            if response.is_error.unwrap_or_default() {
250                return Err(anyhow!(response.text_contents()));
251            }
252
253            Ok(())
254        })
255    }
256
257    fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
258        let mut sessions = self.sessions.borrow_mut();
259
260        if let Some(cancel_tx) = sessions
261            .get_mut(session_id)
262            .and_then(|session| session.cancel_tx.take())
263        {
264            cancel_tx.send(()).ok();
265        }
266    }
267}
268
269impl CodexConnection {
270    pub fn handle_session_notification(
271        notification: acp::SessionNotification,
272        threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
273        cx: &mut AsyncApp,
274    ) {
275        let threads = threads.borrow();
276        let Some(thread) = threads
277            .get(&notification.session_id)
278            .and_then(|session| session.thread.upgrade())
279        else {
280            log::error!(
281                "Thread not found for session ID: {}",
282                notification.session_id
283            );
284            return;
285        };
286
287        thread
288            .update(cx, |thread, cx| {
289                thread.handle_session_update(notification.update, cx)
290            })
291            .log_err();
292    }
293}
294
295impl Drop for CodexConnection {
296    fn drop(&mut self) {
297        self.client.stop().log_err();
298    }
299}
300
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::AgentServerCommand;
    use std::path::Path;

    crate::common_e2e_tests!(Codex, allow_option_id = "approve");

    /// Command that runs a locally built Codex binary, with no extra arguments or environment.
    ///
    /// The binary path is `../../../codex/codex-rs/target/debug/codex`, relative to this
    /// crate's manifest directory. This presumably points at a Codex checkout next to the
    /// Zed repository; adjust it if your layout differs.
    pub fn local_command() -> AgentServerCommand {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
        AgentServerCommand {
            path: manifest_dir.join("../../../codex/codex-rs/target/debug/codex"),
            args: Vec::new(),
            env: None,
        }
    }
}