codex.rs

  1use agent_client_protocol as acp;
  2use anyhow::anyhow;
  3use collections::HashMap;
  4use context_server::listener::McpServerTool;
  5use context_server::types::requests;
  6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  7use futures::channel::{mpsc, oneshot};
  8use project::Project;
  9use settings::SettingsStore;
 10use smol::stream::StreamExt as _;
 11use std::cell::RefCell;
 12use std::rc::Rc;
 13use std::{path::Path, sync::Arc};
 14use util::ResultExt;
 15
 16use anyhow::{Context, Result};
 17use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
 18
 19use crate::mcp_server::ZedMcpServer;
 20use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
 21use acp_thread::{AcpThread, AgentConnection};
 22
 23#[derive(Clone)]
 24pub struct Codex;
 25
 26impl AgentServer for Codex {
 27    fn name(&self) -> &'static str {
 28        "Codex"
 29    }
 30
 31    fn empty_state_headline(&self) -> &'static str {
 32        "Welcome to Codex"
 33    }
 34
 35    fn empty_state_message(&self) -> &'static str {
 36        "What can I help with?"
 37    }
 38
 39    fn logo(&self) -> ui::IconName {
 40        ui::IconName::AiOpenAi
 41    }
 42
 43    fn connect(
 44        &self,
 45        _root_dir: &Path,
 46        project: &Entity<Project>,
 47        cx: &mut App,
 48    ) -> Task<Result<Rc<dyn AgentConnection>>> {
 49        let project = project.clone();
 50        cx.spawn(async move |cx| {
 51            let settings = cx.read_global(|settings: &SettingsStore, _| {
 52                settings.get::<AllAgentServersSettings>(None).codex.clone()
 53            })?;
 54
 55            let Some(command) =
 56                AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
 57            else {
 58                anyhow::bail!("Failed to find codex binary");
 59            };
 60
 61            let client: Arc<ContextServer> = ContextServer::stdio(
 62                ContextServerId("codex-mcp-server".into()),
 63                ContextServerCommand {
 64                    path: command.path,
 65                    args: command.args,
 66                    env: command.env,
 67                },
 68            )
 69            .into();
 70            ContextServer::start(client.clone(), cx).await?;
 71
 72            let (notification_tx, mut notification_rx) = mpsc::unbounded();
 73            client
 74                .client()
 75                .context("Failed to subscribe")?
 76                .on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
 77                    move |notification, _cx| {
 78                        let notification_tx = notification_tx.clone();
 79                        log::trace!(
 80                            "ACP Notification: {}",
 81                            serde_json::to_string_pretty(&notification).unwrap()
 82                        );
 83
 84                        if let Some(notification) =
 85                            serde_json::from_value::<acp::SessionNotification>(notification)
 86                                .log_err()
 87                        {
 88                            notification_tx.unbounded_send(notification).ok();
 89                        }
 90                    }
 91                });
 92
 93            let sessions = Rc::new(RefCell::new(HashMap::default()));
 94
 95            let notification_handler_task = cx.spawn({
 96                let sessions = sessions.clone();
 97                async move |cx| {
 98                    while let Some(notification) = notification_rx.next().await {
 99                        CodexConnection::handle_session_notification(
100                            notification,
101                            sessions.clone(),
102                            cx,
103                        )
104                    }
105                }
106            });
107
108            let connection = CodexConnection {
109                client,
110                sessions,
111                _notification_handler_task: notification_handler_task,
112            };
113            Ok(Rc::new(connection) as _)
114        })
115    }
116}
117
118struct CodexConnection {
119    client: Arc<context_server::ContextServer>,
120    sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
121    _notification_handler_task: Task<()>,
122}
123
124struct CodexSession {
125    thread: WeakEntity<AcpThread>,
126    cancel_tx: Option<oneshot::Sender<()>>,
127    _mcp_server: ZedMcpServer,
128}
129
130impl AgentConnection for CodexConnection {
131    fn name(&self) -> &'static str {
132        "Codex"
133    }
134
135    fn new_thread(
136        self: Rc<Self>,
137        project: Entity<Project>,
138        cwd: &Path,
139        cx: &mut AsyncApp,
140    ) -> Task<Result<Entity<AcpThread>>> {
141        let client = self.client.client();
142        let sessions = self.sessions.clone();
143        let cwd = cwd.to_path_buf();
144        cx.spawn(async move |cx| {
145            let client = client.context("MCP server is not initialized yet")?;
146            let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
147
148            let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
149
150            let response = client
151                .request::<requests::CallTool>(context_server::types::CallToolParams {
152                    name: acp::NEW_SESSION_TOOL_NAME.into(),
153                    arguments: Some(serde_json::to_value(acp::NewSessionArguments {
154                        mcp_servers: [(
155                            mcp_server::SERVER_NAME.to_string(),
156                            mcp_server.server_config()?,
157                        )]
158                        .into(),
159                        client_tools: acp::ClientTools {
160                            request_permission: Some(acp::McpToolId {
161                                mcp_server: mcp_server::SERVER_NAME.into(),
162                                tool_name: mcp_server::RequestPermissionTool::NAME.into(),
163                            }),
164                            read_text_file: Some(acp::McpToolId {
165                                mcp_server: mcp_server::SERVER_NAME.into(),
166                                tool_name: mcp_server::ReadTextFileTool::NAME.into(),
167                            }),
168                            write_text_file: Some(acp::McpToolId {
169                                mcp_server: mcp_server::SERVER_NAME.into(),
170                                tool_name: mcp_server::WriteTextFileTool::NAME.into(),
171                            }),
172                        },
173                        cwd,
174                    })?),
175                    meta: None,
176                })
177                .await?;
178
179            if response.is_error.unwrap_or_default() {
180                return Err(anyhow!(response.text_contents()));
181            }
182
183            let result = serde_json::from_value::<acp::NewSessionOutput>(
184                response.structured_content.context("Empty response")?,
185            )?;
186
187            let thread =
188                cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
189
190            thread_tx.send(thread.downgrade())?;
191
192            let session = CodexSession {
193                thread: thread.downgrade(),
194                cancel_tx: None,
195                _mcp_server: mcp_server,
196            };
197            sessions.borrow_mut().insert(result.session_id, session);
198
199            Ok(thread)
200        })
201    }
202
203    fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
204        Task::ready(Err(anyhow!("Authentication not supported")))
205    }
206
207    fn prompt(
208        &self,
209        params: agent_client_protocol::PromptArguments,
210        cx: &mut App,
211    ) -> Task<Result<()>> {
212        let client = self.client.client();
213        let sessions = self.sessions.clone();
214
215        cx.foreground_executor().spawn(async move {
216            let client = client.context("MCP server is not initialized yet")?;
217
218            let (new_cancel_tx, cancel_rx) = oneshot::channel();
219            {
220                let mut sessions = sessions.borrow_mut();
221                let session = sessions
222                    .get_mut(&params.session_id)
223                    .context("Session not found")?;
224                session.cancel_tx.replace(new_cancel_tx);
225            }
226
227            let result = client
228                .request_with::<requests::CallTool>(
229                    context_server::types::CallToolParams {
230                        name: acp::PROMPT_TOOL_NAME.into(),
231                        arguments: Some(serde_json::to_value(params)?),
232                        meta: None,
233                    },
234                    Some(cancel_rx),
235                    None,
236                )
237                .await;
238
239            if let Err(err) = &result
240                && err.is::<context_server::client::RequestCanceled>()
241            {
242                return Ok(());
243            }
244
245            let response = result?;
246
247            if response.is_error.unwrap_or_default() {
248                return Err(anyhow!(response.text_contents()));
249            }
250
251            Ok(())
252        })
253    }
254
255    fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
256        let mut sessions = self.sessions.borrow_mut();
257
258        if let Some(cancel_tx) = sessions
259            .get_mut(session_id)
260            .and_then(|session| session.cancel_tx.take())
261        {
262            cancel_tx.send(()).ok();
263        }
264    }
265}
266
267impl CodexConnection {
268    pub fn handle_session_notification(
269        notification: acp::SessionNotification,
270        threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
271        cx: &mut AsyncApp,
272    ) {
273        let threads = threads.borrow();
274        let Some(thread) = threads
275            .get(&notification.session_id)
276            .and_then(|session| session.thread.upgrade())
277        else {
278            log::error!(
279                "Thread not found for session ID: {}",
280                notification.session_id
281            );
282            return;
283        };
284
285        thread
286            .update(cx, |thread, cx| {
287                thread.handle_session_update(notification.update, cx)
288            })
289            .log_err();
290    }
291}
292
293impl Drop for CodexConnection {
294    fn drop(&mut self) {
295        self.client.stop().log_err();
296    }
297}
298
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::AgentServerCommand;
    use std::path::Path;

    crate::common_e2e_tests!(Codex);

    /// Builds a command that runs a locally built Codex CLI in MCP mode (`codex mcp`).
    ///
    /// The binary is looked up at `../../../codex/codex-rs/target/debug/codex`, relative
    /// to this crate's manifest directory. That presumably means a `codex` checkout next
    /// to the Zed checkout, assuming this crate lives at `<zed>/crates/<crate>`
    /// (TODO confirm).
    pub fn local_command() -> AgentServerCommand {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
        AgentServerCommand {
            path: manifest_dir.join("../../../codex/codex-rs/target/debug/codex"),
            args: vec!["mcp".into()],
            env: None,
        }
    }
}