codex.rs

  1use collections::HashMap;
  2use context_server::types::requests::CallTool;
  3use context_server::types::{CallToolParams, ToolResponseContent};
  4use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  5use futures::channel::{mpsc, oneshot};
  6use itertools::Itertools;
  7use project::Project;
  8use serde::de::DeserializeOwned;
  9use settings::SettingsStore;
 10use smol::stream::StreamExt;
 11use std::cell::RefCell;
 12use std::path::{Path, PathBuf};
 13use std::rc::Rc;
 14use std::sync::Arc;
 15
 16use agentic_coding_protocol::{
 17    self as acp, AnyAgentRequest, AnyAgentResult, Client as _, ProtocolVersion,
 18};
 19use anyhow::{Context, Result, anyhow};
 20use futures::future::LocalBoxFuture;
 21use futures::{AsyncWriteExt, FutureExt, SinkExt as _};
 22use gpui::{App, AppContext, Entity, Task};
 23use serde::{Deserialize, Serialize};
 24use util::ResultExt;
 25
 26use crate::mcp_server::{McpConfig, ZedMcpServer};
 27use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
 28use acp_thread::{AcpThread, AgentConnection, OldAcpClientDelegate};
 29
/// Agent server implementation for the Codex CLI.
///
/// Threads are created by launching the `codex` binary in `mcp` mode and
/// talking to it through a context-server (MCP) client; see the
/// `AgentServer` impl for this type.
#[derive(Clone)]
pub struct Codex;
 32
/// Marker type describing the approval request that Codex sends to its MCP
/// client.
///
/// The request uses the `elicitation/create` method, carries a
/// [`CodexElicitation`] as parameters and expects a
/// [`CodexApprovalResponse`] in return.
pub struct CodexApproval;
impl context_server::types::Request for CodexApproval {
    type Params = CodexElicitation;
    type Response = CodexApprovalResponse;
    const METHOD: &'static str = "elicitation/create";
}
 39
/// Parameters of an elicitation asking for approval to run a command.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecApprovalRequest {
    // These fields are required so that `params`
    // conforms to ElicitRequestParams.
    /// Message accompanying the request; surfaced as the description of the
    /// confirmation prompt.
    pub message: String,
    // #[serde(rename = "requestedSchema")]
    // pub requested_schema: ElicitRequestParamsRequestedSchema,

    // // These are additional fields the client can use to
    // // correlate the request with the codex tool call.
    /// Identifier of the Codex MCP tool call this request belongs to.
    pub codex_mcp_tool_call_id: String,
    // pub codex_event_id: String,
    /// The command (argv form) Codex wants to execute.
    pub codex_command: Vec<String>,
    /// Working directory for the command.
    pub codex_cwd: PathBuf,
}
 55
/// Parameters of an elicitation asking for approval to apply file changes.
#[derive(Debug, Serialize, Deserialize)]
pub struct PatchApprovalRequest {
    /// Message accompanying the request; surfaced as the description of the
    /// confirmation prompt.
    pub message: String,
    // #[serde(rename = "requestedSchema")]
    // pub requested_schema: ElicitRequestParamsRequestedSchema,
    /// Identifier of the Codex MCP tool call this request belongs to.
    pub codex_mcp_tool_call_id: String,
    /// Identifier of the Codex event that triggered this request.
    pub codex_event_id: String,
    /// Optional reason supplied with the request; omitted from the
    /// serialized form when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub codex_reason: Option<String>,
    /// Optional root path associated with the request; omitted from the
    /// serialized form when absent.
    // NOTE(review): presumably a directory Codex asks write access to — confirm
    // against the Codex protocol definition.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub codex_grant_root: Option<PathBuf>,
    /// The proposed change for each affected path.
    pub codex_changes: HashMap<PathBuf, FileChange>,
}
 69
 70#[derive(Debug, Serialize, Deserialize)]
 71#[serde(tag = "codex_elicitation", rename_all = "snake_case")]
 72enum CodexElicitation {
 73    ExecApproval(ExecApprovalRequest),
 74    PatchApproval(PatchApprovalRequest),
 75}
 76
/// A single file change proposed in a [`PatchApprovalRequest`].
///
/// Serialized with snake_case variant names (`add`, `delete`, `update`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FileChange {
    /// Create a file with the given content.
    Add {
        content: String,
    },
    /// Remove the file.
    Delete,
    /// Modify a file.
    Update {
        /// The modification, expressed as a unified diff.
        unified_diff: String,
        /// Optional new path for the file.
        // NOTE(review): presumably set when the update also renames the file — confirm.
        move_path: Option<PathBuf>,
    },
}
 89
/// Response to a [`CodexApproval`] request.
#[derive(Debug, Serialize, Deserialize)]
pub struct CodexApprovalResponse {
    /// What the user decided about the requested action.
    pub decision: ReviewDecision,
}
 94
/// User's decision in response to an approval request (either an
/// [`ExecApprovalRequest`] or a [`PatchApprovalRequest`]).
///
/// Serialized with snake_case variant names. The default is
/// [`ReviewDecision::Denied`].
#[derive(Debug, Default, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ReviewDecision {
    /// User has approved this command and the agent should execute it.
    Approved,

    /// User has approved this command and wants to automatically approve any
    /// future identical instances (`command` and `cwd` match exactly) for the
    /// remainder of the session.
    ApprovedForSession,

    /// User has denied this command and the agent should not execute it, but
    /// it should continue the session and try something else.
    #[default]
    Denied,

    /// User has denied this command and the agent should not do anything until
    /// the user's next command.
    Abort,
}
116
117impl AgentServer for Codex {
118    fn name(&self) -> &'static str {
119        "Codex"
120    }
121
122    fn empty_state_headline(&self) -> &'static str {
123        self.name()
124    }
125
126    fn empty_state_message(&self) -> &'static str {
127        ""
128    }
129
130    fn logo(&self) -> ui::IconName {
131        ui::IconName::AiOpenAi
132    }
133
134    fn supports_always_allow(&self) -> bool {
135        false
136    }
137
138    fn new_thread(
139        &self,
140        root_dir: &Path,
141        project: &Entity<Project>,
142        cx: &mut App,
143    ) -> Task<Result<Entity<AcpThread>>> {
144        let project = project.clone();
145        let root_dir = root_dir.to_path_buf();
146        let title = self.name().into();
147        cx.spawn(async move |cx| {
148            let (mut delegate_tx, delegate_rx) = watch::channel(None);
149            let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
150
151            let zed_mcp_server = ZedMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
152            let mcp_server_config = zed_mcp_server.server_config()?;
153            // https://github.com/openai/codex/blob/main/codex-rs/config.md
154            let cli_server_config = format!(
155                "mcp_servers.{}={{command = \"{}\", args = [{}]}}",
156                crate::mcp_server::SERVER_NAME,
157                mcp_server_config.command.display(),
158                mcp_server_config
159                    .args
160                    .iter()
161                    .map(|arg| format!("\"{}\"", arg))
162                    .join(", ")
163            );
164
165            let settings = cx.read_global(|settings: &SettingsStore, _| {
166                settings.get::<AllAgentServersSettings>(None).codex.clone()
167            })?;
168
169            let Some(mut command) =
170                AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
171            else {
172                anyhow::bail!("Failed to find codex binary");
173            };
174
175            command
176                .args
177                .extend(["--config".to_string(), cli_server_config]);
178
179            let codex_mcp_client: Arc<ContextServer> = ContextServer::stdio(
180                ContextServerId("codex-mcp-server".into()),
181                ContextServerCommand {
182                    path: command.path,
183                    args: command.args,
184                    env: command.env,
185                },
186            )
187            .into();
188
189            ContextServer::start(codex_mcp_client.clone(), cx).await?;
190            // todo! stop
191
192            let (notification_tx, mut notification_rx) = mpsc::unbounded();
193            let (request_tx, mut request_rx) = mpsc::unbounded();
194
195            let client = codex_mcp_client
196                .client()
197                .context("Failed to subscribe to server")?;
198            client.on_notification("codex/event", {
199                move |event, cx| {
200                    let mut notification_tx = notification_tx.clone();
201                    cx.background_spawn(async move {
202                        log::trace!("Notification: {:?}", event);
203                        if let Some(event) = serde_json::from_value::<CodexEvent>(event).log_err() {
204                            notification_tx.send(event.msg).await.log_err();
205                        }
206                    })
207                    .detach();
208                }
209            });
210
211            client.on_request::<CodexApproval, _>({
212                let delegate = delegate.clone();
213                {
214                    move |elicitation, cx| {
215                        let (tx, rx) = oneshot::channel::<Result<CodexApprovalResponse>>();
216                        request_tx.send((elicitation, tx));
217                        cx.foreground_executor().spawn(rx)
218                    }
219                }
220            });
221
222            cx.new(|cx| {
223                let delegate = OldAcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
224                delegate_tx.send(Some(delegate.clone())).log_err();
225
226                let handler_task = cx.spawn({
227                    let delegate = delegate.clone();
228                    let tool_id_map = tool_id_map.clone();
229                    async move |_, _cx| {
230                        while let Some(notification) = notification_rx.next().await {
231                            CodexAgentConnection::handle_acp_notification(
232                                &delegate,
233                                notification,
234                                &tool_id_map,
235                            )
236                            .await
237                            .log_err();
238                        }
239                    }
240                });
241
242                let request_task = cx.spawn({
243                    let delegate = delegate.clone();
244                    let tool_id_map = tool_id_map.clone();
245                    async move |_, _cx| {
246                        while let Some((elicitation, respond)) = request_tx.next().await {
247                            let confirmation = match elicitation {
248                                CodexElicitation::ExecApproval(exec) => {
249                                    let inner_command =
250                                        strip_bash_lc_and_escape(&exec.codex_command);
251
252                                    acp::RequestToolCallConfirmationParams {
253                                        tool_call: acp::PushToolCallParams {
254                                            label: todo!(),
255                                            icon: acp::Icon::Terminal,
256                                            content: None,
257                                            locations: vec![],
258                                        },
259                                        confirmation: acp::ToolCallConfirmation::Execute {
260                                            root_command: inner_command
261                                                .split(" ")
262                                                .next()
263                                                .unwrap_or_default()
264                                                .to_string(),
265                                            command: inner_command,
266                                            description: Some(exec.message),
267                                        },
268                                    }
269                                }
270                                CodexElicitation::PatchApproval(patch) => {
271                                    acp::RequestToolCallConfirmationParams {
272                                        tool_call: acp::PushToolCallParams {
273                                            label: "Edit".to_string(),
274                                            icon: acp::Icon::Pencil,
275                                            content: None, // todo!()
276                                            locations: patch
277                                                .codex_changes
278                                                .keys()
279                                                .map(|path| acp::ToolCallLocation {
280                                                    path: path.clone(),
281                                                    line: None,
282                                                })
283                                                .collect(),
284                                        },
285                                        confirmation: acp::ToolCallConfirmation::Edit {
286                                            description: Some(patch.message),
287                                        },
288                                    }
289                                }
290                            };
291
292                            let task = cx.spawn(async move |cx| {
293                                let response = delegate
294                                    .request_tool_call_confirmation(confirmation)
295                                    .await?;
296
297                                let decision = match response.outcome {
298                                    acp::ToolCallConfirmationOutcome::Allow => {
299                                        ReviewDecision::Approved
300                                    }
301                                    acp::ToolCallConfirmationOutcome::AlwaysAllow
302                                    | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
303                                    | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => {
304                                        ReviewDecision::ApprovedForSession
305                                    }
306                                    acp::ToolCallConfirmationOutcome::Reject => {
307                                        ReviewDecision::Denied
308                                    }
309                                    acp::ToolCallConfirmationOutcome::Cancel => {
310                                        ReviewDecision::Abort
311                                    }
312                                };
313
314                                Ok(CodexApprovalResponse { decision })
315                            });
316                        }
317
318                        cx.spawn(async move |cx| {
319                            tx.send(task.await).ok();
320                        })
321                    }
322                });
323
324                let connection = CodexAgentConnection {
325                    root_dir,
326                    codex_mcp: codex_mcp_client,
327                    cancel_request_tx: Default::default(),
328                    tool_id_map: tool_id_map.clone(),
329                    _handler_task: handler_task,
330                    _request_task: request_task,
331                    _zed_mcp: zed_mcp_server,
332                };
333
334                acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
335            })
336        })
337    }
338}
339
340impl AgentConnection for CodexAgentConnection {
341    /// Send a request to the agent and wait for a response.
342    fn request_any(
343        &self,
344        params: AnyAgentRequest,
345    ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
346        let client = self.codex_mcp.client();
347        let root_dir = self.root_dir.clone();
348        let cancel_request_tx = self.cancel_request_tx.clone();
349        async move {
350            let client = client.context("Codex MCP server is not initialized")?;
351
352            match params {
353                // todo: consider sending an empty request so we get the init response?
354                AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
355                    acp::InitializeResponse {
356                        is_authenticated: true,
357                        protocol_version: ProtocolVersion::latest(),
358                    },
359                )),
360                AnyAgentRequest::AuthenticateParams(_) => {
361                    Err(anyhow!("Authentication not supported"))
362                }
363                AnyAgentRequest::SendUserMessageParams(message) => {
364                    let (new_cancel_tx, cancel_rx) = oneshot::channel();
365                    cancel_request_tx.borrow_mut().replace(new_cancel_tx);
366
367                    client
368                        .cancellable_request::<CallTool>(
369                            CallToolParams {
370                                name: "codex".into(),
371                                arguments: Some(serde_json::to_value(CodexToolCallParam {
372                                    prompt: message
373                                        .chunks
374                                        .into_iter()
375                                        .filter_map(|chunk| match chunk {
376                                            acp::UserMessageChunk::Text { text } => Some(text),
377                                            acp::UserMessageChunk::Path { .. } => {
378                                                // todo!
379                                                None
380                                            }
381                                        })
382                                        .collect(),
383                                    cwd: root_dir,
384                                })?),
385                                meta: None,
386                            },
387                            cancel_rx,
388                        )
389                        .await?;
390
391                    Ok(AnyAgentResult::SendUserMessageResponse(
392                        acp::SendUserMessageResponse,
393                    ))
394                }
395                AnyAgentRequest::CancelSendMessageParams(_) => {
396                    if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() {
397                        if let Some(cancel_tx) = borrow.take() {
398                            cancel_tx.send(()).ok();
399                        }
400                    }
401
402                    Ok(AnyAgentResult::CancelSendMessageResponse(
403                        acp::CancelSendMessageResponse,
404                    ))
405                }
406            }
407        }
408        .boxed_local()
409    }
410}
411
/// Connection between an ACP thread and a running Codex MCP server.
struct CodexAgentConnection {
    /// Context server wrapping the `codex mcp` process.
    codex_mcp: Arc<context_server::ContextServer>,
    /// Directory passed to Codex as the working directory of each prompt.
    root_dir: PathBuf,
    /// Cancellation handle of the most recent user message, if any.
    cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
    /// Maps Codex call ids to the ACP tool call ids created for them.
    tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
    /// Task draining Codex event notifications; held here so it is dropped
    /// together with the connection.
    _handler_task: Task<()>,
    /// Task draining Codex approval requests; held here so it is dropped
    /// together with the connection.
    _request_task: Task<()>,
    /// Zed's own MCP server handed to Codex; held here so it is dropped
    /// together with the connection.
    _zed_mcp: ZedMcpServer,
}
421
422impl CodexAgentConnection {
423    async fn handle_acp_notification(
424        delegate: &OldAcpClientDelegate,
425        event: AcpNotification,
426        tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
427    ) -> Result<()> {
428        match event {
429            AcpNotification::AgentMessage(message) => {
430                delegate
431                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
432                        chunk: acp::AssistantMessageChunk::Text {
433                            text: message.message,
434                        },
435                    })
436                    .await?;
437            }
438            AcpNotification::AgentReasoning(message) => {
439                delegate
440                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
441                        chunk: acp::AssistantMessageChunk::Thought {
442                            thought: message.text,
443                        },
444                    })
445                    .await?
446            }
447            AcpNotification::McpToolCallBegin(event) => {
448                let result = delegate
449                    .push_tool_call(acp::PushToolCallParams {
450                        label: format!("`{}: {}`", event.server, event.tool),
451                        icon: acp::Icon::Hammer,
452                        content: event.arguments.and_then(|args| {
453                            Some(acp::ToolCallContent::Markdown {
454                                markdown: md_codeblock(
455                                    "json",
456                                    &serde_json::to_string_pretty(&args).ok()?,
457                                ),
458                            })
459                        }),
460                        locations: vec![],
461                    })
462                    .await?;
463
464                tool_id_map.borrow_mut().insert(event.call_id, result.id);
465            }
466            AcpNotification::McpToolCallEnd(event) => {
467                let acp_call_id = tool_id_map
468                    .borrow_mut()
469                    .remove(&event.call_id)
470                    .context("Missing tool call")?;
471
472                let (status, content) = match event.result {
473                    Ok(value) => {
474                        if let Ok(response) =
475                            serde_json::from_value::<context_server::types::CallToolResponse>(value)
476                        {
477                            (
478                                acp::ToolCallStatus::Finished,
479                                mcp_tool_content_to_acp(response.content),
480                            )
481                        } else {
482                            (
483                                acp::ToolCallStatus::Error,
484                                Some(acp::ToolCallContent::Markdown {
485                                    markdown: "Failed to parse tool response".to_string(),
486                                }),
487                            )
488                        }
489                    }
490                    Err(error) => (
491                        acp::ToolCallStatus::Error,
492                        Some(acp::ToolCallContent::Markdown { markdown: error }),
493                    ),
494                };
495
496                delegate
497                    .update_tool_call(acp::UpdateToolCallParams {
498                        tool_call_id: acp_call_id,
499                        status,
500                        content,
501                    })
502                    .await?;
503            }
504            AcpNotification::ExecCommandBegin(event) => {
505                let inner_command = strip_bash_lc_and_escape(&event.command);
506
507                let result = delegate
508                    .push_tool_call(acp::PushToolCallParams {
509                        label: format!("`{}`", inner_command),
510                        icon: acp::Icon::Terminal,
511                        content: None,
512                        locations: vec![],
513                    })
514                    .await?;
515
516                tool_id_map.borrow_mut().insert(event.call_id, result.id);
517            }
518            AcpNotification::ExecCommandEnd(event) => {
519                let acp_call_id = tool_id_map
520                    .borrow_mut()
521                    .remove(&event.call_id)
522                    .context("Missing tool call")?;
523
524                let mut content = String::new();
525                if !event.stdout.is_empty() {
526                    use std::fmt::Write;
527                    writeln!(
528                        &mut content,
529                        "### Output\n\n{}",
530                        md_codeblock("", &event.stdout)
531                    )
532                    .unwrap();
533                }
534                if !event.stdout.is_empty() && !event.stderr.is_empty() {
535                    use std::fmt::Write;
536                    writeln!(&mut content).unwrap();
537                }
538                if !event.stderr.is_empty() {
539                    use std::fmt::Write;
540                    writeln!(
541                        &mut content,
542                        "### Error\n\n{}",
543                        md_codeblock("", &event.stderr)
544                    )
545                    .unwrap();
546                }
547                let success = event.exit_code == 0;
548                if !success {
549                    use std::fmt::Write;
550                    writeln!(&mut content, "\nExit code: `{}`", event.exit_code).unwrap();
551                }
552
553                delegate
554                    .update_tool_call(acp::UpdateToolCallParams {
555                        tool_call_id: acp_call_id,
556                        status: if success {
557                            acp::ToolCallStatus::Finished
558                        } else {
559                            acp::ToolCallStatus::Error
560                        },
561                        content: Some(acp::ToolCallContent::Markdown { markdown: content }),
562                    })
563                    .await?;
564            }
565            AcpNotification::ExecApprovalRequest(event) => {
566                let inner_command = strip_bash_lc_and_escape(&event.command);
567                let root_command = inner_command
568                    .split(" ")
569                    .next()
570                    .map(|s| s.to_string())
571                    .unwrap_or_default();
572
573                let response = delegate
574                    .request_tool_call_confirmation(acp::RequestToolCallConfirmationParams {
575                        tool_call: acp::PushToolCallParams {
576                            label: format!("`{}`", inner_command),
577                            icon: acp::Icon::Terminal,
578                            content: None,
579                            locations: vec![],
580                        },
581                        confirmation: acp::ToolCallConfirmation::Execute {
582                            command: inner_command,
583                            root_command,
584                            description: event.reason,
585                        },
586                    })
587                    .await?;
588
589                tool_id_map.borrow_mut().insert(event.call_id, response.id);
590
591                // todo! approval
592            }
593            AcpNotification::Other => {}
594        }
595
596        Ok(())
597    }
598}
599
// todo! use types from h2a crate when we have one

/// Arguments of the `codex` tool call used to send a prompt to Codex.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub(crate) struct CodexToolCallParam {
    /// The prompt text.
    pub prompt: String,
    /// Working directory for the Codex session.
    pub cwd: PathBuf,
}
608
/// Envelope of a `codex/event` notification; only the `msg` payload is read.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CodexEvent {
    /// The event itself.
    pub msg: AcpNotification,
}
613
/// Event payload received from Codex, discriminated by its `type` field
/// (snake_case variant names, e.g. `agent_message`).
///
/// Event types not listed here deserialize to [`AcpNotification::Other`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AcpNotification {
    /// Assistant message text.
    AgentMessage(AgentMessageEvent),
    /// Assistant reasoning text.
    AgentReasoning(AgentReasoningEvent),
    /// An MCP tool call started.
    McpToolCallBegin(McpToolCallBeginEvent),
    /// An MCP tool call finished.
    McpToolCallEnd(McpToolCallEndEvent),
    /// A command execution started.
    ExecCommandBegin(ExecCommandBeginEvent),
    /// A command execution finished.
    ExecCommandEnd(ExecCommandEndEvent),
    /// Codex asks for approval to execute a command.
    ExecApprovalRequest(ExecApprovalRequestEvent),
    /// Any other event type; ignored by the handler.
    #[serde(other)]
    Other,
}
627
/// Payload of an `agent_message` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMessageEvent {
    /// Assistant message text.
    pub message: String,
}
632
/// Payload of an `agent_reasoning` event.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentReasoningEvent {
    /// Reasoning text, shown as an assistant "thought".
    pub text: String,
}
637
/// Payload of an `mcp_tool_call_begin` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallBeginEvent {
    /// Codex identifier of the call; matched by the corresponding end event.
    pub call_id: String,
    /// Name of the MCP server providing the tool.
    pub server: String,
    /// Name of the tool being called.
    pub tool: String,
    /// Arguments passed to the tool, if any.
    pub arguments: Option<serde_json::Value>,
}
645
/// Payload of an `mcp_tool_call_end` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallEndEvent {
    /// Codex identifier of the call, as given in the begin event.
    pub call_id: String,
    /// Raw tool response on success, or an error message on failure.
    pub result: Result<serde_json::Value, String>,
}
651
/// Payload of an `exec_command_begin` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandBeginEvent {
    /// Codex identifier of the call; matched by the corresponding end event.
    pub call_id: String,
    /// The command being executed, in argv form.
    pub command: Vec<String>,
    /// Working directory of the command.
    pub cwd: PathBuf,
}
658
/// Payload of an `exec_command_end` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandEndEvent {
    /// Codex identifier of the call, as given in the begin event.
    pub call_id: String,
    /// Captured standard output.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
    /// Exit code of the command; `0` is treated as success.
    pub exit_code: i32,
}
666
/// Payload of an `exec_approval_request` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecApprovalRequestEvent {
    /// Codex identifier of the call awaiting approval.
    pub call_id: String,
    /// The command Codex wants to execute, in argv form.
    pub command: Vec<String>,
    /// Working directory of the command.
    pub cwd: PathBuf,
    /// Optional reason shown as the confirmation description; omitted from
    /// the serialized form when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
675
676// Helper functions
/// Wraps `content` in a Markdown fenced code block tagged with `lang`.
///
/// The fence is at least three backticks and always one backtick longer than
/// the longest backtick run inside `content`, so content that itself contains
/// a fence (e.g. the output of printing a Markdown file) cannot terminate the
/// block early. A trailing newline in `content` is reused rather than
/// doubled, so the closing fence always starts on its own line.
fn md_codeblock(lang: &str, content: &str) -> String {
    // Splitting on every non-backtick character leaves exactly the runs of
    // backticks; backticks are single-byte, so `len` is the run length.
    let longest_backtick_run = content
        .split(|ch: char| ch != '`')
        .map(str::len)
        .max()
        .unwrap_or(0);
    let fence = "`".repeat(longest_backtick_run.max(2) + 1);

    if content.ends_with('\n') {
        format!("{fence}{lang}\n{content}{fence}")
    } else {
        format!("{fence}{lang}\n{content}\n{fence}")
    }
}
684
685fn strip_bash_lc_and_escape(command: &[String]) -> String {
686    match command {
687        // exactly three items
688        [first, second, third]
689            // first two must be "bash", "-lc"
690            if first == "bash" && second == "-lc" =>
691        {
692            third.clone()
693        }
694        _ => escape_command(command),
695    }
696}
697
698fn escape_command(command: &[String]) -> String {
699    shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
700}
701
702fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::ToolCallContent> {
703    let mut content = String::new();
704
705    for chunk in chunks {
706        match chunk {
707            ToolResponseContent::Text { text } => content.push_str(&text),
708            ToolResponseContent::Image { .. } => {
709                // todo!
710            }
711            ToolResponseContent::Audio { .. } => {
712                // todo!
713            }
714            ToolResponseContent::Resource { .. } => {
715                // todo!
716            }
717        }
718    }
719
720    if !content.is_empty() {
721        Some(acp::ToolCallContent::Markdown { markdown: content })
722    } else {
723        None
724    }
725}
726
#[cfg(test)]
pub mod tests {
    use super::*;

    crate::common_e2e_tests!(Codex);

    /// Command for running the end-to-end tests against a locally built Codex.
    ///
    /// Points at the debug binary of a `codex` checkout located three
    /// directories above this crate's manifest directory. The Rust workspace
    /// of that repository lives in `codex-rs`, so the binary is expected at
    /// `codex/codex-rs/target/debug/codex`.
    pub fn local_command() -> AgentServerCommand {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../codex/codex-rs/target/debug/codex");

        AgentServerCommand {
            path: cli_path,
            args: vec!["mcp".into()],
            env: None,
        }
    }
}