codex.rs

  1use collections::HashMap;
  2use context_server::types::requests::CallTool;
  3use context_server::types::{CallToolParams, ToolResponseContent};
  4use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  5use futures::channel::{mpsc, oneshot};
  6use itertools::Itertools;
  7use project::Project;
  8use serde::de::DeserializeOwned;
  9use settings::SettingsStore;
 10use smol::stream::StreamExt;
 11use std::cell::RefCell;
 12use std::path::{Path, PathBuf};
 13use std::rc::Rc;
 14use std::sync::Arc;
 15
 16use agentic_coding_protocol::{
 17    self as acp, AnyAgentRequest, AnyAgentResult, Client as _, ProtocolVersion,
 18};
 19use anyhow::{Context, Result, anyhow};
 20use futures::future::LocalBoxFuture;
 21use futures::{AsyncWriteExt, FutureExt, SinkExt as _};
 22use gpui::{App, AppContext, Entity, Task};
 23use serde::{Deserialize, Serialize};
 24use util::ResultExt;
 25
 26use crate::mcp_server::{McpConfig, ZedMcpServer};
 27use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
 28use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
 29
/// The OpenAI Codex agent server.
///
/// Threads are backed by the `codex` CLI launched in MCP mode (`codex mcp`).
#[derive(Clone)]
pub struct Codex;
 32
/// Marker type for the MCP `elicitation/create` request that the Codex server
/// sends to this client; used to register a handler via `on_request`.
pub struct CodexApproval;
impl context_server::types::Request for CodexApproval {
    type Params = CodexApprovalRequest;
    type Response = CodexApprovalResponse;
    // MCP method name of the request this type handles.
    const METHOD: &'static str = "elicitation/create";
}
 39
 40#[derive(Debug, Serialize, Deserialize)]
 41pub struct CodexApprovalRequest {
 42    // These fields are required so that `params`
 43    // conforms to ElicitRequestParams.
 44    pub message: String,
 45    // #[serde(rename = "requestedSchema")]
 46    // pub requested_schema: ElicitRequestParamsRequestedSchema,
 47
 48    // // These are additional fields the client can use to
 49    // // correlate the request with the codex tool call.
 50    // pub codex_elicitation: String,
 51    // pub codex_mcp_tool_call_id: String,
 52    // pub codex_event_id: String,
 53    // pub codex_command: Vec<String>,
 54    // pub codex_cwd: PathBuf,
 55}
 56
 57#[derive(Debug, Serialize, Deserialize)]
 58pub struct CodexApprovalResponse {
 59    pub decision: ReviewDecision,
 60}
 61
 62/// User's decision in response to an ExecApprovalRequest.
 63#[derive(Debug, Default, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
 64#[serde(rename_all = "snake_case")]
 65pub enum ReviewDecision {
 66    /// User has approved this command and the agent should execute it.
 67    Approved,
 68
 69    /// User has approved this command and wants to automatically approve any
 70    /// future identical instances (`command` and `cwd` match exactly) for the
 71    /// remainder of the session.
 72    ApprovedForSession,
 73
 74    /// User has denied this command and the agent should not execute it, but
 75    /// it should continue the session and try something else.
 76    #[default]
 77    Denied,
 78
 79    /// User has denied this command and the agent should not do anything until
 80    /// the user's next command.
 81    Abort,
 82}
 83
 84impl AgentServer for Codex {
 85    fn name(&self) -> &'static str {
 86        "Codex"
 87    }
 88
 89    fn empty_state_headline(&self) -> &'static str {
 90        self.name()
 91    }
 92
 93    fn empty_state_message(&self) -> &'static str {
 94        ""
 95    }
 96
 97    fn logo(&self) -> ui::IconName {
 98        ui::IconName::AiOpenAi
 99    }
100
101    fn supports_always_allow(&self) -> bool {
102        false
103    }
104
105    fn new_thread(
106        &self,
107        root_dir: &Path,
108        project: &Entity<Project>,
109        cx: &mut App,
110    ) -> Task<Result<Entity<AcpThread>>> {
111        let project = project.clone();
112        let root_dir = root_dir.to_path_buf();
113        let title = self.name().into();
114        cx.spawn(async move |cx| {
115            let (mut delegate_tx, delegate_rx) = watch::channel(None);
116            let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
117
118            let zed_mcp_server = ZedMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
119            let mcp_server_config = zed_mcp_server.server_config()?;
120            // https://github.com/openai/codex/blob/main/codex-rs/config.md
121            let cli_server_config = format!(
122                "mcp_servers.{}={{command = \"{}\", args = [{}]}}",
123                crate::mcp_server::SERVER_NAME,
124                mcp_server_config.command.display(),
125                mcp_server_config
126                    .args
127                    .iter()
128                    .map(|arg| format!("\"{}\"", arg))
129                    .join(", ")
130            );
131
132            let settings = cx.read_global(|settings: &SettingsStore, _| {
133                settings.get::<AllAgentServersSettings>(None).codex.clone()
134            })?;
135
136            let Some(mut command) =
137                AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
138            else {
139                anyhow::bail!("Failed to find codex binary");
140            };
141
142            command
143                .args
144                .extend(["--config".to_string(), cli_server_config]);
145
146            let codex_mcp_client: Arc<ContextServer> = ContextServer::stdio(
147                ContextServerId("codex-mcp-server".into()),
148                ContextServerCommand {
149                    path: command.path,
150                    args: command.args,
151                    env: command.env,
152                },
153            )
154            .into();
155
156            ContextServer::start(codex_mcp_client.clone(), cx).await?;
157            // todo! stop
158
159            let (notification_tx, mut notification_rx) = mpsc::unbounded();
160
161            let client = codex_mcp_client
162                .client()
163                .context("Failed to subscribe to server")?;
164            client.on_request::<CodexApproval, _>({
165                move |elicitation: CodexApprovalRequest, cx| {
166                    cx.spawn(async move |cx| anyhow::bail!("oops"))
167                }
168            });
169            client.on_notification("codex/event", {
170                move |event, cx| {
171                    let mut notification_tx = notification_tx.clone();
172                    cx.background_spawn(async move {
173                        log::trace!("Notification: {:?}", event);
174                        if let Some(event) = serde_json::from_value::<CodexEvent>(event).log_err() {
175                            notification_tx.send(event.msg).await.log_err();
176                        }
177                    })
178                    .detach();
179                }
180            });
181
182            cx.new(|cx| {
183                let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
184                delegate_tx.send(Some(delegate.clone())).log_err();
185
186                let handler_task = cx.spawn({
187                    let delegate = delegate.clone();
188                    let tool_id_map = tool_id_map.clone();
189                    async move |_, _cx| {
190                        while let Some(notification) = notification_rx.next().await {
191                            CodexAgentConnection::handle_acp_notification(
192                                &delegate,
193                                notification,
194                                &tool_id_map,
195                            )
196                            .await
197                            .log_err();
198                        }
199                    }
200                });
201
202                let connection = CodexAgentConnection {
203                    root_dir,
204                    codex_mcp: codex_mcp_client,
205                    cancel_request_tx: Default::default(),
206                    tool_id_map: tool_id_map.clone(),
207                    _handler_task: handler_task,
208                    _zed_mcp: zed_mcp_server,
209                };
210
211                acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
212            })
213        })
214    }
215}
216
217impl AgentConnection for CodexAgentConnection {
218    /// Send a request to the agent and wait for a response.
219    fn request_any(
220        &self,
221        params: AnyAgentRequest,
222    ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
223        let client = self.codex_mcp.client();
224        let root_dir = self.root_dir.clone();
225        let cancel_request_tx = self.cancel_request_tx.clone();
226        async move {
227            let client = client.context("Codex MCP server is not initialized")?;
228
229            match params {
230                // todo: consider sending an empty request so we get the init response?
231                AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
232                    acp::InitializeResponse {
233                        is_authenticated: true,
234                        protocol_version: ProtocolVersion::latest(),
235                    },
236                )),
237                AnyAgentRequest::AuthenticateParams(_) => {
238                    Err(anyhow!("Authentication not supported"))
239                }
240                AnyAgentRequest::SendUserMessageParams(message) => {
241                    let (new_cancel_tx, cancel_rx) = oneshot::channel();
242                    cancel_request_tx.borrow_mut().replace(new_cancel_tx);
243
244                    client
245                        .cancellable_request::<CallTool>(
246                            CallToolParams {
247                                name: "codex".into(),
248                                arguments: Some(serde_json::to_value(CodexToolCallParam {
249                                    prompt: message
250                                        .chunks
251                                        .into_iter()
252                                        .filter_map(|chunk| match chunk {
253                                            acp::UserMessageChunk::Text { text } => Some(text),
254                                            acp::UserMessageChunk::Path { .. } => {
255                                                // todo!
256                                                None
257                                            }
258                                        })
259                                        .collect(),
260                                    cwd: root_dir,
261                                })?),
262                                meta: None,
263                            },
264                            cancel_rx,
265                        )
266                        .await?;
267
268                    Ok(AnyAgentResult::SendUserMessageResponse(
269                        acp::SendUserMessageResponse,
270                    ))
271                }
272                AnyAgentRequest::CancelSendMessageParams(_) => {
273                    if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() {
274                        if let Some(cancel_tx) = borrow.take() {
275                            cancel_tx.send(()).ok();
276                        }
277                    }
278
279                    Ok(AnyAgentResult::CancelSendMessageResponse(
280                        acp::CancelSendMessageResponse,
281                    ))
282                }
283            }
284        }
285        .boxed_local()
286    }
287}
288
/// [`AgentConnection`] backed by a `codex mcp` server process.
// NOTE: fields drop in declaration order; keep that in mind before reordering.
struct CodexAgentConnection {
    /// MCP client handle for the Codex server.
    codex_mcp: Arc<context_server::ContextServer>,
    /// Directory passed as `cwd` in each `codex` tool call.
    root_dir: PathBuf,
    /// Cancel sender for the in-flight `codex` tool call: replaced on each new
    /// user message and taken when a cancel request arrives.
    cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
    /// Maps Codex `call_id`s to the ACP tool call ids created for them.
    tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
    /// Task forwarding Codex notifications to the thread; stored (never read) so
    /// it lives as long as the connection.
    _handler_task: Task<()>,
    /// Zed's MCP server that Codex is configured to connect to; stored (never
    /// read) so it lives as long as the connection.
    _zed_mcp: ZedMcpServer,
}
297
298impl CodexAgentConnection {
299    async fn handle_acp_notification(
300        delegate: &AcpClientDelegate,
301        event: AcpNotification,
302        tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
303    ) -> Result<()> {
304        match event {
305            AcpNotification::AgentMessage(message) => {
306                delegate
307                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
308                        chunk: acp::AssistantMessageChunk::Text {
309                            text: message.message,
310                        },
311                    })
312                    .await?;
313            }
314            AcpNotification::AgentReasoning(message) => {
315                delegate
316                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
317                        chunk: acp::AssistantMessageChunk::Thought {
318                            thought: message.text,
319                        },
320                    })
321                    .await?
322            }
323            AcpNotification::McpToolCallBegin(event) => {
324                let result = delegate
325                    .push_tool_call(acp::PushToolCallParams {
326                        label: format!("`{}: {}`", event.server, event.tool),
327                        icon: acp::Icon::Hammer,
328                        content: event.arguments.and_then(|args| {
329                            Some(acp::ToolCallContent::Markdown {
330                                markdown: md_codeblock(
331                                    "json",
332                                    &serde_json::to_string_pretty(&args).ok()?,
333                                ),
334                            })
335                        }),
336                        locations: vec![],
337                    })
338                    .await?;
339
340                tool_id_map.borrow_mut().insert(event.call_id, result.id);
341            }
342            AcpNotification::McpToolCallEnd(event) => {
343                let acp_call_id = tool_id_map
344                    .borrow_mut()
345                    .remove(&event.call_id)
346                    .context("Missing tool call")?;
347
348                let (status, content) = match event.result {
349                    Ok(value) => {
350                        if let Ok(response) =
351                            serde_json::from_value::<context_server::types::CallToolResponse>(value)
352                        {
353                            (
354                                acp::ToolCallStatus::Finished,
355                                mcp_tool_content_to_acp(response.content),
356                            )
357                        } else {
358                            (
359                                acp::ToolCallStatus::Error,
360                                Some(acp::ToolCallContent::Markdown {
361                                    markdown: "Failed to parse tool response".to_string(),
362                                }),
363                            )
364                        }
365                    }
366                    Err(error) => (
367                        acp::ToolCallStatus::Error,
368                        Some(acp::ToolCallContent::Markdown { markdown: error }),
369                    ),
370                };
371
372                delegate
373                    .update_tool_call(acp::UpdateToolCallParams {
374                        tool_call_id: acp_call_id,
375                        status,
376                        content,
377                    })
378                    .await?;
379            }
380            AcpNotification::ExecCommandBegin(event) => {
381                let inner_command = strip_bash_lc_and_escape(&event.command);
382
383                let result = delegate
384                    .push_tool_call(acp::PushToolCallParams {
385                        label: format!("`{}`", inner_command),
386                        icon: acp::Icon::Terminal,
387                        content: None,
388                        locations: vec![],
389                    })
390                    .await?;
391
392                tool_id_map.borrow_mut().insert(event.call_id, result.id);
393            }
394            AcpNotification::ExecCommandEnd(event) => {
395                let acp_call_id = tool_id_map
396                    .borrow_mut()
397                    .remove(&event.call_id)
398                    .context("Missing tool call")?;
399
400                let mut content = String::new();
401                if !event.stdout.is_empty() {
402                    use std::fmt::Write;
403                    writeln!(
404                        &mut content,
405                        "### Output\n\n{}",
406                        md_codeblock("", &event.stdout)
407                    )
408                    .unwrap();
409                }
410                if !event.stdout.is_empty() && !event.stderr.is_empty() {
411                    use std::fmt::Write;
412                    writeln!(&mut content).unwrap();
413                }
414                if !event.stderr.is_empty() {
415                    use std::fmt::Write;
416                    writeln!(
417                        &mut content,
418                        "### Error\n\n{}",
419                        md_codeblock("", &event.stderr)
420                    )
421                    .unwrap();
422                }
423                let success = event.exit_code == 0;
424                if !success {
425                    use std::fmt::Write;
426                    writeln!(&mut content, "\nExit code: `{}`", event.exit_code).unwrap();
427                }
428
429                delegate
430                    .update_tool_call(acp::UpdateToolCallParams {
431                        tool_call_id: acp_call_id,
432                        status: if success {
433                            acp::ToolCallStatus::Finished
434                        } else {
435                            acp::ToolCallStatus::Error
436                        },
437                        content: Some(acp::ToolCallContent::Markdown { markdown: content }),
438                    })
439                    .await?;
440            }
441            AcpNotification::ExecApprovalRequest(event) => {
442                let inner_command = strip_bash_lc_and_escape(&event.command);
443                let root_command = inner_command
444                    .split(" ")
445                    .next()
446                    .map(|s| s.to_string())
447                    .unwrap_or_default();
448
449                let response = delegate
450                    .request_tool_call_confirmation(acp::RequestToolCallConfirmationParams {
451                        tool_call: acp::PushToolCallParams {
452                            label: format!("`{}`", inner_command),
453                            icon: acp::Icon::Terminal,
454                            content: None,
455                            locations: vec![],
456                        },
457                        confirmation: acp::ToolCallConfirmation::Execute {
458                            command: inner_command,
459                            root_command,
460                            description: event.reason,
461                        },
462                    })
463                    .await?;
464
465                tool_id_map.borrow_mut().insert(event.call_id, response.id);
466
467                // todo! approval
468            }
469            AcpNotification::Other => {}
470        }
471
472        Ok(())
473    }
474}
475
476/// todo! use types from h2a crate when we have one
477
478#[derive(Debug, Clone, Serialize, Deserialize)]
479#[serde(rename_all = "kebab-case")]
480pub(crate) struct CodexToolCallParam {
481    pub prompt: String,
482    pub cwd: PathBuf,
483}
484
485#[derive(Debug, Clone, Serialize, Deserialize)]
486struct CodexEvent {
487    pub msg: AcpNotification,
488}
489
490#[derive(Debug, Clone, Serialize, Deserialize)]
491#[serde(tag = "type", rename_all = "snake_case")]
492pub enum AcpNotification {
493    AgentMessage(AgentMessageEvent),
494    AgentReasoning(AgentReasoningEvent),
495    McpToolCallBegin(McpToolCallBeginEvent),
496    McpToolCallEnd(McpToolCallEndEvent),
497    ExecCommandBegin(ExecCommandBeginEvent),
498    ExecCommandEnd(ExecCommandEndEvent),
499    ExecApprovalRequest(ExecApprovalRequestEvent),
500    #[serde(other)]
501    Other,
502}
503
504#[derive(Debug, Clone, Serialize, Deserialize)]
505pub struct AgentMessageEvent {
506    pub message: String,
507}
508
509#[derive(Debug, Clone, Deserialize, Serialize)]
510pub struct AgentReasoningEvent {
511    pub text: String,
512}
513
514#[derive(Debug, Clone, Serialize, Deserialize)]
515pub struct McpToolCallBeginEvent {
516    pub call_id: String,
517    pub server: String,
518    pub tool: String,
519    pub arguments: Option<serde_json::Value>,
520}
521
522#[derive(Debug, Clone, Serialize, Deserialize)]
523pub struct McpToolCallEndEvent {
524    pub call_id: String,
525    pub result: Result<serde_json::Value, String>,
526}
527
528#[derive(Debug, Clone, Serialize, Deserialize)]
529pub struct ExecCommandBeginEvent {
530    pub call_id: String,
531    pub command: Vec<String>,
532    pub cwd: PathBuf,
533}
534
535#[derive(Debug, Clone, Serialize, Deserialize)]
536pub struct ExecCommandEndEvent {
537    pub call_id: String,
538    pub stdout: String,
539    pub stderr: String,
540    pub exit_code: i32,
541}
542
543#[derive(Debug, Clone, Serialize, Deserialize)]
544pub struct ExecApprovalRequestEvent {
545    pub call_id: String,
546    pub command: Vec<String>,
547    pub cwd: PathBuf,
548    #[serde(skip_serializing_if = "Option::is_none")]
549    pub reason: Option<String>,
550}
551
552// Helper functions
/// Wraps `content` in a fenced Markdown code block tagged with `lang`.
///
/// The fence is one backtick longer than the longest run of backticks inside
/// `content` (and at least three), so content that itself contains fences
/// (e.g. Markdown printed by a command) cannot close the block early. A
/// trailing newline is added before the closing fence only if `content` lacks one.
fn md_codeblock(lang: &str, content: &str) -> String {
    let longest_backtick_run = content
        .split(|c: char| c != '`')
        .map(str::len)
        .max()
        .unwrap_or(0);
    let fence = "`".repeat(longest_backtick_run.max(2) + 1);
    let newline = if content.ends_with('\n') { "" } else { "\n" };
    format!("{fence}{lang}\n{content}{newline}{fence}")
}
560
561fn strip_bash_lc_and_escape(command: &[String]) -> String {
562    match command {
563        // exactly three items
564        [first, second, third]
565            // first two must be "bash", "-lc"
566            if first == "bash" && second == "-lc" =>
567        {
568            third.clone()
569        }
570        _ => escape_command(command),
571    }
572}
573
574fn escape_command(command: &[String]) -> String {
575    shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
576}
577
578fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::ToolCallContent> {
579    let mut content = String::new();
580
581    for chunk in chunks {
582        match chunk {
583            ToolResponseContent::Text { text } => content.push_str(&text),
584            ToolResponseContent::Image { .. } => {
585                // todo!
586            }
587            ToolResponseContent::Audio { .. } => {
588                // todo!
589            }
590            ToolResponseContent::Resource { .. } => {
591                // todo!
592            }
593        }
594    }
595
596    if !content.is_empty() {
597        Some(acp::ToolCallContent::Markdown { markdown: content })
598    } else {
599        None
600    }
601}
602
#[cfg(test)]
pub mod tests {
    use super::*;

    crate::common_e2e_tests!(Codex);

    /// Command for a locally built Codex binary (`codex mcp`).
    ///
    /// Assumes the openai/codex repository is checked out at `../../../codex`
    /// relative to this crate's manifest directory and built in debug mode; its
    /// Rust workspace lives in `codex-rs`.
    pub fn local_command() -> AgentServerCommand {
        let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("../../../codex/codex-rs/target/debug/codex");

        AgentServerCommand {
            path: cli_path,
            args: vec!["mcp".into()],
            env: None,
        }
    }
}