codex.rs

  1use collections::HashMap;
  2use context_server::types::requests::CallTool;
  3use context_server::types::{CallToolParams, ToolResponseContent};
  4use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  5use futures::channel::{mpsc, oneshot};
  6use project::Project;
  7use settings::SettingsStore;
  8use smol::stream::StreamExt;
  9use std::cell::RefCell;
 10use std::path::{Path, PathBuf};
 11use std::rc::Rc;
 12use std::sync::Arc;
 13
 14use agentic_coding_protocol::{
 15    self as acp, AnyAgentRequest, AnyAgentResult, Client as _, ProtocolVersion,
 16};
 17use anyhow::{Context, Result, anyhow};
 18use futures::future::LocalBoxFuture;
 19use futures::{AsyncWriteExt, FutureExt, SinkExt as _};
 20use gpui::{App, AppContext, Entity, Task};
 21use serde::{Deserialize, Serialize};
 22use util::ResultExt;
 23
 24use crate::mcp_server::{McpConfig, ZedMcpServer};
 25use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
 26use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
 27
/// The Codex agent server.
///
/// Implements [`AgentServer`] by launching the `codex mcp` command as a stdio
/// MCP server and driving it through [`CodexAgentConnection`].
#[derive(Clone)]
pub struct Codex;
 30
 31impl AgentServer for Codex {
 32    fn name(&self) -> &'static str {
 33        "Codex"
 34    }
 35
 36    fn empty_state_headline(&self) -> &'static str {
 37        self.name()
 38    }
 39
 40    fn empty_state_message(&self) -> &'static str {
 41        ""
 42    }
 43
 44    fn logo(&self) -> ui::IconName {
 45        ui::IconName::AiOpenAi
 46    }
 47
 48    fn supports_always_allow(&self) -> bool {
 49        false
 50    }
 51
 52    fn new_thread(
 53        &self,
 54        root_dir: &Path,
 55        project: &Entity<Project>,
 56        cx: &mut App,
 57    ) -> Task<Result<Entity<AcpThread>>> {
 58        let project = project.clone();
 59        let root_dir = root_dir.to_path_buf();
 60        let title = self.name().into();
 61        cx.spawn(async move |cx| {
 62            let (mut delegate_tx, delegate_rx) = watch::channel(None);
 63            let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
 64
 65            let zed_mcp_server = ZedMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
 66
 67            let mut mcp_servers = HashMap::default();
 68            mcp_servers.insert(
 69                crate::mcp_server::SERVER_NAME.to_string(),
 70                zed_mcp_server.server_config()?,
 71            );
 72            let mcp_config = McpConfig { mcp_servers };
 73
 74            // todo! pass zed mcp server to codex tool
 75            let mcp_config_file = tempfile::NamedTempFile::new()?;
 76            let (mcp_config_file, _mcp_config_path) = mcp_config_file.into_parts();
 77
 78            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 79            mcp_config_file
 80                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 81                .await?;
 82            mcp_config_file.flush().await?;
 83
 84            let settings = cx.read_global(|settings: &SettingsStore, _| {
 85                settings.get::<AllAgentServersSettings>(None).codex.clone()
 86            })?;
 87
 88            let Some(command) =
 89                AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
 90            else {
 91                anyhow::bail!("Failed to find codex binary");
 92            };
 93
 94            let codex_mcp_client: Arc<ContextServer> = ContextServer::stdio(
 95                ContextServerId("codex-mcp-server".into()),
 96                ContextServerCommand {
 97                    path: command.path,
 98                    args: command.args,
 99                    env: command.env,
100                },
101            )
102            .into();
103
104            ContextServer::start(codex_mcp_client.clone(), cx).await?;
105            // todo! stop
106
107            let (notification_tx, mut notification_rx) = mpsc::unbounded();
108
109            codex_mcp_client
110                .client()
111                .context("Failed to subscribe to server")?
112                .on_notification("codex/event", {
113                    move |event, cx| {
114                        let mut notification_tx = notification_tx.clone();
115                        cx.background_spawn(async move {
116                            log::trace!("Notification: {:?}", event);
117                            if let Some(event) =
118                                serde_json::from_value::<CodexEvent>(event).log_err()
119                            {
120                                notification_tx.send(event.msg).await.log_err();
121                            }
122                        })
123                        .detach();
124                    }
125                });
126
127            cx.new(|cx| {
128                let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
129                delegate_tx.send(Some(delegate.clone())).log_err();
130
131                let handler_task = cx.spawn({
132                    let delegate = delegate.clone();
133                    let tool_id_map = tool_id_map.clone();
134                    async move |_, _cx| {
135                        while let Some(notification) = notification_rx.next().await {
136                            CodexAgentConnection::handle_acp_notification(
137                                &delegate,
138                                notification,
139                                &tool_id_map,
140                            )
141                            .await
142                            .log_err();
143                        }
144                    }
145                });
146
147                let connection = CodexAgentConnection {
148                    root_dir,
149                    codex_mcp: codex_mcp_client,
150                    cancel_request_tx: Default::default(),
151                    tool_id_map: tool_id_map.clone(),
152                    _handler_task: handler_task,
153                    _zed_mcp: zed_mcp_server,
154                };
155
156                acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
157            })
158        })
159    }
160}
161
162impl AgentConnection for CodexAgentConnection {
163    /// Send a request to the agent and wait for a response.
164    fn request_any(
165        &self,
166        params: AnyAgentRequest,
167    ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
168        let client = self.codex_mcp.client();
169        let root_dir = self.root_dir.clone();
170        let cancel_request_tx = self.cancel_request_tx.clone();
171        async move {
172            let client = client.context("Codex MCP server is not initialized")?;
173
174            match params {
175                // todo: consider sending an empty request so we get the init response?
176                AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
177                    acp::InitializeResponse {
178                        is_authenticated: true,
179                        protocol_version: ProtocolVersion::latest(),
180                    },
181                )),
182                AnyAgentRequest::AuthenticateParams(_) => {
183                    Err(anyhow!("Authentication not supported"))
184                }
185                AnyAgentRequest::SendUserMessageParams(message) => {
186                    let (new_cancel_tx, cancel_rx) = oneshot::channel();
187                    cancel_request_tx.borrow_mut().replace(new_cancel_tx);
188
189                    client
190                        .cancellable_request::<CallTool>(
191                            CallToolParams {
192                                name: "codex".into(),
193                                arguments: Some(serde_json::to_value(CodexToolCallParam {
194                                    prompt: message
195                                        .chunks
196                                        .into_iter()
197                                        .filter_map(|chunk| match chunk {
198                                            acp::UserMessageChunk::Text { text } => Some(text),
199                                            acp::UserMessageChunk::Path { .. } => {
200                                                // todo!
201                                                None
202                                            }
203                                        })
204                                        .collect(),
205                                    cwd: root_dir,
206                                })?),
207                                meta: None,
208                            },
209                            cancel_rx,
210                        )
211                        .await?;
212
213                    Ok(AnyAgentResult::SendUserMessageResponse(
214                        acp::SendUserMessageResponse,
215                    ))
216                }
217                AnyAgentRequest::CancelSendMessageParams(_) => {
218                    if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() {
219                        if let Some(cancel_tx) = borrow.take() {
220                            cancel_tx.send(()).ok();
221                        }
222                    }
223
224                    Ok(AnyAgentResult::CancelSendMessageResponse(
225                        acp::CancelSendMessageResponse,
226                    ))
227                }
228            }
229        }
230        .boxed_local()
231    }
232}
233
/// [`AgentConnection`] backed by a Codex MCP server process.
struct CodexAgentConnection {
    /// Handle to the Codex MCP server. User messages are sent to it as `codex` tool calls.
    codex_mcp: Arc<context_server::ContextServer>,
    /// Directory passed to Codex as `cwd` with every prompt.
    root_dir: PathBuf,
    /// Cancel handle for the most recently sent user message, if any.
    cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
    /// Maps Codex `call_id`s to the ACP tool-call ids created for them.
    tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
    /// Task that forwards Codex events to the ACP delegate.
    /// Never read; held so it lives as long as the connection.
    _handler_task: Task<()>,
    /// Zed's own MCP server.
    /// Never read; held so it lives as long as the connection.
    _zed_mcp: ZedMcpServer,
}
242
243impl CodexAgentConnection {
244    async fn handle_acp_notification(
245        delegate: &AcpClientDelegate,
246        event: AcpNotification,
247        tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
248    ) -> Result<()> {
249        match event {
250            AcpNotification::AgentMessage(message) => {
251                delegate
252                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
253                        chunk: acp::AssistantMessageChunk::Text {
254                            text: message.message,
255                        },
256                    })
257                    .await?;
258            }
259            AcpNotification::AgentReasoning(message) => {
260                delegate
261                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
262                        chunk: acp::AssistantMessageChunk::Thought {
263                            thought: message.text,
264                        },
265                    })
266                    .await?
267            }
268            AcpNotification::McpToolCallBegin(event) => {
269                let result = delegate
270                    .push_tool_call(acp::PushToolCallParams {
271                        label: format!("`{}: {}`", event.server, event.tool),
272                        icon: acp::Icon::Hammer,
273                        content: event.arguments.and_then(|args| {
274                            Some(acp::ToolCallContent::Markdown {
275                                markdown: md_codeblock(
276                                    "json",
277                                    &serde_json::to_string_pretty(&args).ok()?,
278                                ),
279                            })
280                        }),
281                        locations: vec![],
282                    })
283                    .await?;
284
285                tool_id_map.borrow_mut().insert(event.call_id, result.id);
286            }
287            AcpNotification::McpToolCallEnd(event) => {
288                let acp_call_id = tool_id_map
289                    .borrow_mut()
290                    .remove(&event.call_id)
291                    .context("Missing tool call")?;
292
293                let (status, content) = match event.result {
294                    Ok(value) => {
295                        if let Ok(response) =
296                            serde_json::from_value::<context_server::types::CallToolResponse>(value)
297                        {
298                            (
299                                acp::ToolCallStatus::Finished,
300                                mcp_tool_content_to_acp(response.content),
301                            )
302                        } else {
303                            (
304                                acp::ToolCallStatus::Error,
305                                Some(acp::ToolCallContent::Markdown {
306                                    markdown: "Failed to parse tool response".to_string(),
307                                }),
308                            )
309                        }
310                    }
311                    Err(error) => (
312                        acp::ToolCallStatus::Error,
313                        Some(acp::ToolCallContent::Markdown { markdown: error }),
314                    ),
315                };
316
317                delegate
318                    .update_tool_call(acp::UpdateToolCallParams {
319                        tool_call_id: acp_call_id,
320                        status,
321                        content,
322                    })
323                    .await?;
324            }
325            AcpNotification::ExecCommandBegin(event) => {
326                let inner_command = strip_bash_lc_and_escape(&event.command);
327
328                let result = delegate
329                    .push_tool_call(acp::PushToolCallParams {
330                        label: format!("`{}`", inner_command),
331                        icon: acp::Icon::Terminal,
332                        content: None,
333                        locations: vec![],
334                    })
335                    .await?;
336
337                tool_id_map.borrow_mut().insert(event.call_id, result.id);
338            }
339            AcpNotification::ExecCommandEnd(event) => {
340                let acp_call_id = tool_id_map
341                    .borrow_mut()
342                    .remove(&event.call_id)
343                    .context("Missing tool call")?;
344
345                let mut content = String::new();
346                if !event.stdout.is_empty() {
347                    use std::fmt::Write;
348                    writeln!(
349                        &mut content,
350                        "### Output\n\n{}",
351                        md_codeblock("", &event.stdout)
352                    )
353                    .unwrap();
354                }
355                if !event.stdout.is_empty() && !event.stderr.is_empty() {
356                    use std::fmt::Write;
357                    writeln!(&mut content).unwrap();
358                }
359                if !event.stderr.is_empty() {
360                    use std::fmt::Write;
361                    writeln!(
362                        &mut content,
363                        "### Error\n\n{}",
364                        md_codeblock("", &event.stderr)
365                    )
366                    .unwrap();
367                }
368                let success = event.exit_code == 0;
369                if !success {
370                    use std::fmt::Write;
371                    writeln!(&mut content, "\nExit code: `{}`", event.exit_code).unwrap();
372                }
373
374                delegate
375                    .update_tool_call(acp::UpdateToolCallParams {
376                        tool_call_id: acp_call_id,
377                        status: if success {
378                            acp::ToolCallStatus::Finished
379                        } else {
380                            acp::ToolCallStatus::Error
381                        },
382                        content: Some(acp::ToolCallContent::Markdown { markdown: content }),
383                    })
384                    .await?;
385            }
386            AcpNotification::ExecApprovalRequest(event) => {
387                let inner_command = strip_bash_lc_and_escape(&event.command);
388                let root_command = inner_command
389                    .split(" ")
390                    .next()
391                    .map(|s| s.to_string())
392                    .unwrap_or_default();
393
394                let response = delegate
395                    .request_tool_call_confirmation(acp::RequestToolCallConfirmationParams {
396                        tool_call: acp::PushToolCallParams {
397                            label: format!("`{}`", inner_command),
398                            icon: acp::Icon::Terminal,
399                            content: None,
400                            locations: vec![],
401                        },
402                        confirmation: acp::ToolCallConfirmation::Execute {
403                            command: inner_command,
404                            root_command,
405                            description: event.reason,
406                        },
407                    })
408                    .await?;
409
410                tool_id_map.borrow_mut().insert(event.call_id, response.id);
411
412                // todo! approval
413            }
414            AcpNotification::Other => {}
415        }
416
417        Ok(())
418    }
419}
420
// todo! replace these hand-written Codex protocol types with the ones from the h2a crate once it exists
422
/// Arguments of the `codex` MCP tool call that carries a user prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub(crate) struct CodexToolCallParam {
    /// Prompt text sent to Codex.
    pub prompt: String,
    /// Working directory for the Codex session.
    pub cwd: PathBuf,
}
429
/// Envelope of a `codex/event` notification.
///
/// Only the `msg` payload is deserialized; serde ignores any other fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CodexEvent {
    pub msg: AcpNotification,
}
434
/// Codex event payloads handled by this integration.
///
/// The variant is selected by the internally tagged `type` field, written in
/// snake_case (for example `"exec_command_begin"`). Any unrecognized `type`
/// deserializes to [`AcpNotification::Other`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AcpNotification {
    /// `agent_message`: assistant text.
    AgentMessage(AgentMessageEvent),
    /// `agent_reasoning`: assistant reasoning text.
    AgentReasoning(AgentReasoningEvent),
    /// `mcp_tool_call_begin`: Codex started an MCP tool call.
    McpToolCallBegin(McpToolCallBeginEvent),
    /// `mcp_tool_call_end`: an MCP tool call finished.
    McpToolCallEnd(McpToolCallEndEvent),
    /// `exec_command_begin`: Codex started a shell command.
    ExecCommandBegin(ExecCommandBeginEvent),
    /// `exec_command_end`: a shell command finished.
    ExecCommandEnd(ExecCommandEndEvent),
    /// `exec_approval_request`: Codex asks permission to run a command.
    ExecApprovalRequest(ExecApprovalRequestEvent),
    /// Any other event type; ignored by the handler.
    #[serde(other)]
    Other,
}
448
/// Payload of an `agent_message` event.
///
/// The text is streamed to the thread as an assistant message chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMessageEvent {
    pub message: String,
}
453
/// Payload of an `agent_reasoning` event.
///
/// The text is streamed to the thread as a thought chunk.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentReasoningEvent {
    pub text: String,
}
458
/// Payload of an `mcp_tool_call_begin` event: Codex is calling `tool` on MCP `server`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallBeginEvent {
    /// Codex's id for this call; the matching end event carries the same id.
    pub call_id: String,
    pub server: String,
    pub tool: String,
    /// Tool arguments, if any.
    pub arguments: Option<serde_json::Value>,
}
466
/// Payload of an `mcp_tool_call_end` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallEndEvent {
    pub call_id: String,
    /// `Ok` holds the raw tool response and `Err` an error message.
    /// Serde encodes this as `{"Ok": ...}` or `{"Err": "..."}`.
    pub result: Result<serde_json::Value, String>,
}
472
/// Payload of an `exec_command_begin` event: Codex started a shell command.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandBeginEvent {
    pub call_id: String,
    /// The command as an argument vector (e.g. `["bash", "-lc", "<script>"]`).
    pub command: Vec<String>,
    pub cwd: PathBuf,
}
479
/// Payload of an `exec_command_end` event: a shell command finished.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandEndEvent {
    pub call_id: String,
    pub stdout: String,
    pub stderr: String,
    /// Process exit code; the handler treats `0` as success.
    pub exit_code: i32,
}
487
/// Payload of an `exec_approval_request` event: Codex asks to run `command`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecApprovalRequestEvent {
    pub call_id: String,
    /// The command as an argument vector.
    pub command: Vec<String>,
    pub cwd: PathBuf,
    /// Optional explanation of why the command is needed; omitted when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
496
497// Helper functions
/// Wraps `content` in a fenced Markdown code block tagged with `lang`.
///
/// The fence is at least three backticks long. It is made one backtick longer
/// than the longest run of backticks inside `content`, so content that contains
/// a fence of its own cannot close the block early. Such content includes
/// command output that prints Markdown.
///
/// A newline is inserted before the closing fence only when `content` does not
/// already end with one.
fn md_codeblock(lang: &str, content: &str) -> String {
    // Splitting on every non-backtick character leaves exactly the backtick runs.
    // '`' is one byte, so `len` counts backticks.
    let longest_backtick_run = content
        .split(|c: char| c != '`')
        .map(str::len)
        .max()
        .unwrap_or(0);
    let fence = "`".repeat(longest_backtick_run.max(2) + 1);
    let newline = if content.ends_with('\n') { "" } else { "\n" };
    format!("{fence}{lang}\n{content}{newline}{fence}")
}
505
506fn strip_bash_lc_and_escape(command: &[String]) -> String {
507    match command {
508        // exactly three items
509        [first, second, third]
510            // first two must be "bash", "-lc"
511            if first == "bash" && second == "-lc" =>
512        {
513            third.clone()
514        }
515        _ => escape_command(command),
516    }
517}
518
519fn escape_command(command: &[String]) -> String {
520    shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
521}
522
523fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::ToolCallContent> {
524    let mut content = String::new();
525
526    for chunk in chunks {
527        match chunk {
528            ToolResponseContent::Text { text } => content.push_str(&text),
529            ToolResponseContent::Image { .. } => {
530                // todo!
531            }
532            ToolResponseContent::Audio { .. } => {
533                // todo!
534            }
535            ToolResponseContent::Resource { .. } => {
536                // todo!
537            }
538        }
539    }
540
541    if !content.is_empty() {
542        Some(acp::ToolCallContent::Markdown { markdown: content })
543    } else {
544        None
545    }
546}
547
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Returns an [`AgentServerCommand`] that runs `codex mcp` from a debug build.
    ///
    /// The Codex checkout is expected at `../../../codex` relative to this
    /// crate's manifest directory.
    pub fn local_command() -> AgentServerCommand {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
        AgentServerCommand {
            path: manifest_dir.join("../../../codex/target/debug/codex"),
            args: vec!["mcp".into()],
            env: None,
        }
    }
}