codex.rs

  1use collections::HashMap;
  2use context_server::types::requests::CallTool;
  3use context_server::types::{CallToolParams, ToolResponseContent};
  4use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  5use futures::channel::{mpsc, oneshot};
  6use project::Project;
  7use settings::SettingsStore;
  8use smol::stream::StreamExt;
  9use std::cell::RefCell;
 10use std::path::{Path, PathBuf};
 11use std::rc::Rc;
 12use std::sync::Arc;
 13
 14use agentic_coding_protocol::{
 15    self as acp, AnyAgentRequest, AnyAgentResult, Client as _, ProtocolVersion,
 16};
 17use anyhow::{Context, Result, anyhow};
 18use futures::future::LocalBoxFuture;
 19use futures::{AsyncWriteExt, FutureExt, SinkExt as _};
 20use gpui::{App, AppContext, Entity, Task};
 21use serde::{Deserialize, Serialize};
 22use util::ResultExt;
 23
 24use crate::mcp_server::{McpConfig, ZedMcpServer};
 25use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
 26use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
 27
/// [`AgentServer`] that runs the Codex CLI (`codex mcp`) as an MCP server and
/// presents its session as an [`AcpThread`].
#[derive(Clone)]
pub struct Codex;
 30
 31impl AgentServer for Codex {
 32    fn name(&self) -> &'static str {
 33        "Codex"
 34    }
 35
 36    fn empty_state_headline(&self) -> &'static str {
 37        self.name()
 38    }
 39
 40    fn empty_state_message(&self) -> &'static str {
 41        ""
 42    }
 43
 44    fn logo(&self) -> ui::IconName {
 45        ui::IconName::AiOpenAi
 46    }
 47
 48    fn supports_always_allow(&self) -> bool {
 49        false
 50    }
 51
 52    fn new_thread(
 53        &self,
 54        root_dir: &Path,
 55        project: &Entity<Project>,
 56        cx: &mut App,
 57    ) -> Task<Result<Entity<AcpThread>>> {
 58        let project = project.clone();
 59        let root_dir = root_dir.to_path_buf();
 60        let title = self.name().into();
 61        cx.spawn(async move |cx| {
 62            let (mut delegate_tx, delegate_rx) = watch::channel(None);
 63            let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
 64
 65            let zed_mcp_server = ZedMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
 66
 67            let mut mcp_servers = HashMap::default();
 68            mcp_servers.insert(
 69                crate::mcp_server::SERVER_NAME.to_string(),
 70                zed_mcp_server.server_config()?,
 71            );
 72            let mcp_config = McpConfig { mcp_servers };
 73
 74            // todo! pass zed mcp server to codex tool
 75            let mcp_config_file = tempfile::NamedTempFile::new()?;
 76            let (mcp_config_file, _mcp_config_path) = mcp_config_file.into_parts();
 77
 78            let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
 79            mcp_config_file
 80                .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
 81                .await?;
 82            mcp_config_file.flush().await?;
 83
 84            let settings = cx.read_global(|settings: &SettingsStore, _| {
 85                settings.get::<AllAgentServersSettings>(None).codex.clone()
 86            })?;
 87
 88            let Some(command) =
 89                AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
 90            else {
 91                anyhow::bail!("Failed to find codex binary");
 92            };
 93
 94            let codex_mcp_client: Arc<ContextServer> = ContextServer::stdio(
 95                ContextServerId("codex-mcp-server".into()),
 96                ContextServerCommand {
 97                    // todo! should we change ContextServerCommand to take a PathBuf?
 98                    path: command.path.to_string_lossy().to_string(),
 99                    args: command.args,
100                    env: command.env,
101                },
102            )
103            .into();
104
105            ContextServer::start(codex_mcp_client.clone(), cx).await?;
106            // todo! stop
107
108            let (notification_tx, mut notification_rx) = mpsc::unbounded();
109
110            codex_mcp_client
111                .client()
112                .context("Failed to subscribe to server")?
113                .on_notification("codex/event", {
114                    move |event, cx| {
115                        let mut notification_tx = notification_tx.clone();
116                        cx.background_spawn(async move {
117                            log::trace!("Notification: {:?}", event);
118                            if let Some(event) =
119                                serde_json::from_value::<CodexEvent>(event).log_err()
120                            {
121                                notification_tx.send(event.msg).await.log_err();
122                            }
123                        })
124                        .detach();
125                    }
126                });
127
128            cx.new(|cx| {
129                // todo! handle notifications
130                let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
131                delegate_tx.send(Some(delegate.clone())).log_err();
132
133                let handler_task = cx.spawn({
134                    let delegate = delegate.clone();
135                    let tool_id_map = tool_id_map.clone();
136                    async move |_, _cx| {
137                        while let Some(notification) = notification_rx.next().await {
138                            CodexAgentConnection::handle_acp_notification(
139                                &delegate,
140                                notification,
141                                &tool_id_map,
142                            )
143                            .await
144                            .log_err();
145                        }
146                    }
147                });
148
149                let connection = CodexAgentConnection {
150                    root_dir,
151                    codex_mcp: codex_mcp_client,
152                    cancel_request_tx: Default::default(),
153                    tool_id_map: tool_id_map.clone(),
154                    _handler_task: handler_task,
155                    _zed_mcp: zed_mcp_server,
156                };
157
158                acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
159            })
160        })
161    }
162}
163
164impl AgentConnection for CodexAgentConnection {
165    /// Send a request to the agent and wait for a response.
166    fn request_any(
167        &self,
168        params: AnyAgentRequest,
169    ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
170        let client = self.codex_mcp.client();
171        let root_dir = self.root_dir.clone();
172        let cancel_request_tx = self.cancel_request_tx.clone();
173        async move {
174            let client = client.context("Codex MCP server is not initialized")?;
175
176            match params {
177                // todo: consider sending an empty request so we get the init response?
178                AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
179                    acp::InitializeResponse {
180                        is_authenticated: true,
181                        protocol_version: ProtocolVersion::latest(),
182                    },
183                )),
184                AnyAgentRequest::AuthenticateParams(_) => {
185                    Err(anyhow!("Authentication not supported"))
186                }
187                AnyAgentRequest::SendUserMessageParams(message) => {
188                    let (new_cancel_tx, cancel_rx) = oneshot::channel();
189                    cancel_request_tx.borrow_mut().replace(new_cancel_tx);
190
191                    client
192                        .cancellable_request::<CallTool>(
193                            CallToolParams {
194                                name: "codex".into(),
195                                arguments: Some(serde_json::to_value(CodexToolCallParam {
196                                    prompt: message
197                                        .chunks
198                                        .into_iter()
199                                        .filter_map(|chunk| match chunk {
200                                            acp::UserMessageChunk::Text { text } => Some(text),
201                                            acp::UserMessageChunk::Path { .. } => {
202                                                // todo!
203                                                None
204                                            }
205                                        })
206                                        .collect(),
207                                    cwd: root_dir,
208                                })?),
209                                meta: None,
210                            },
211                            cancel_rx,
212                        )
213                        .await?;
214
215                    Ok(AnyAgentResult::SendUserMessageResponse(
216                        acp::SendUserMessageResponse,
217                    ))
218                }
219                AnyAgentRequest::CancelSendMessageParams(_) => {
220                    if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() {
221                        if let Some(cancel_tx) = borrow.take() {
222                            cancel_tx.send(()).ok();
223                        }
224                    }
225
226                    Ok(AnyAgentResult::CancelSendMessageResponse(
227                        acp::CancelSendMessageResponse,
228                    ))
229                }
230            }
231        }
232        .boxed_local()
233    }
234}
235
/// [`AgentConnection`] that drives a Codex session through the `codex mcp`
/// server.
///
/// NOTE: Rust drops fields in declaration order, so keep that in mind before
/// reordering.
struct CodexAgentConnection {
    /// Shared handle to the Codex MCP server. `request_any` issues its
    /// requests through this handle's client.
    codex_mcp: Arc<context_server::ContextServer>,
    /// Sent as `cwd` with every `codex` tool call.
    root_dir: PathBuf,
    /// Cancellation sender for the most recent `codex` tool call, if any.
    cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
    /// Maps Codex `call_id`s to the ACP tool-call ids shown in the thread.
    tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
    /// Task that applies Codex notifications to the thread. It is held so it
    /// lives as long as this connection.
    _handler_task: Task<()>,
    /// Zed's MCP server, held for the lifetime of this connection.
    _zed_mcp: ZedMcpServer,
}
244
245impl CodexAgentConnection {
246    async fn handle_acp_notification(
247        delegate: &AcpClientDelegate,
248        event: AcpNotification,
249        tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
250    ) -> Result<()> {
251        match event {
252            AcpNotification::AgentMessage(message) => {
253                delegate
254                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
255                        chunk: acp::AssistantMessageChunk::Text {
256                            text: message.message,
257                        },
258                    })
259                    .await?;
260            }
261            AcpNotification::AgentReasoning(message) => {
262                delegate
263                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
264                        chunk: acp::AssistantMessageChunk::Thought {
265                            thought: message.text,
266                        },
267                    })
268                    .await?
269            }
270            AcpNotification::McpToolCallBegin(event) => {
271                let result = delegate
272                    .push_tool_call(acp::PushToolCallParams {
273                        label: format!("`{}: {}`", event.server, event.tool),
274                        icon: acp::Icon::Hammer,
275                        content: event.arguments.and_then(|args| {
276                            Some(acp::ToolCallContent::Markdown {
277                                markdown: md_codeblock(
278                                    "json",
279                                    &serde_json::to_string_pretty(&args).ok()?,
280                                ),
281                            })
282                        }),
283                        locations: vec![],
284                    })
285                    .await?;
286
287                tool_id_map.borrow_mut().insert(event.call_id, result.id);
288            }
289            AcpNotification::McpToolCallEnd(event) => {
290                let acp_call_id = tool_id_map
291                    .borrow_mut()
292                    .remove(&event.call_id)
293                    .context("Missing tool call")?;
294
295                let (status, content) = match event.result {
296                    Ok(value) => {
297                        if let Ok(response) =
298                            serde_json::from_value::<context_server::types::CallToolResponse>(value)
299                        {
300                            (
301                                acp::ToolCallStatus::Finished,
302                                mcp_tool_content_to_acp(response.content),
303                            )
304                        } else {
305                            (
306                                acp::ToolCallStatus::Error,
307                                Some(acp::ToolCallContent::Markdown {
308                                    markdown: "Failed to parse tool response".to_string(),
309                                }),
310                            )
311                        }
312                    }
313                    Err(error) => (
314                        acp::ToolCallStatus::Error,
315                        Some(acp::ToolCallContent::Markdown { markdown: error }),
316                    ),
317                };
318
319                delegate
320                    .update_tool_call(acp::UpdateToolCallParams {
321                        tool_call_id: acp_call_id,
322                        status,
323                        content,
324                    })
325                    .await?;
326            }
327            AcpNotification::ExecCommandBegin(event) => {
328                let inner_command = strip_bash_lc_and_escape(&event.command);
329
330                let result = delegate
331                    .push_tool_call(acp::PushToolCallParams {
332                        label: format!("`{}`", inner_command),
333                        icon: acp::Icon::Terminal,
334                        content: None,
335                        locations: vec![],
336                    })
337                    .await?;
338
339                tool_id_map.borrow_mut().insert(event.call_id, result.id);
340            }
341            AcpNotification::ExecCommandEnd(event) => {
342                let acp_call_id = tool_id_map
343                    .borrow_mut()
344                    .remove(&event.call_id)
345                    .context("Missing tool call")?;
346
347                let mut content = String::new();
348                if !event.stdout.is_empty() {
349                    use std::fmt::Write;
350                    writeln!(
351                        &mut content,
352                        "### Output\n\n{}",
353                        md_codeblock("", &event.stdout)
354                    )
355                    .unwrap();
356                }
357                if !event.stdout.is_empty() && !event.stderr.is_empty() {
358                    use std::fmt::Write;
359                    writeln!(&mut content).unwrap();
360                }
361                if !event.stderr.is_empty() {
362                    use std::fmt::Write;
363                    writeln!(
364                        &mut content,
365                        "### Error\n\n{}",
366                        md_codeblock("", &event.stderr)
367                    )
368                    .unwrap();
369                }
370                let success = event.exit_code == 0;
371                if !success {
372                    use std::fmt::Write;
373                    writeln!(&mut content, "\nExit code: `{}`", event.exit_code).unwrap();
374                }
375
376                delegate
377                    .update_tool_call(acp::UpdateToolCallParams {
378                        tool_call_id: acp_call_id,
379                        status: if success {
380                            acp::ToolCallStatus::Finished
381                        } else {
382                            acp::ToolCallStatus::Error
383                        },
384                        content: Some(acp::ToolCallContent::Markdown { markdown: content }),
385                    })
386                    .await?;
387            }
388            AcpNotification::ExecApprovalRequest(event) => {
389                let inner_command = strip_bash_lc_and_escape(&event.command);
390                let root_command = inner_command
391                    .split(" ")
392                    .next()
393                    .map(|s| s.to_string())
394                    .unwrap_or_default();
395
396                let response = delegate
397                    .request_tool_call_confirmation(acp::RequestToolCallConfirmationParams {
398                        tool_call: acp::PushToolCallParams {
399                            label: format!("`{}`", inner_command),
400                            icon: acp::Icon::Terminal,
401                            content: None,
402                            locations: vec![],
403                        },
404                        confirmation: acp::ToolCallConfirmation::Execute {
405                            command: inner_command,
406                            root_command,
407                            description: event.reason,
408                        },
409                    })
410                    .await?;
411
412                tool_id_map.borrow_mut().insert(event.call_id, response.id);
413
414                // todo! approval
415            }
416            AcpNotification::Other => {}
417        }
418
419        Ok(())
420    }
421}
422
// todo! use types from h2a crate when we have one

/// Arguments of the `codex` MCP tool call made for each user message.
// NOTE: both field names are single words, so `kebab-case` renaming currently
// has no effect on the wire names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub(crate) struct CodexToolCallParam {
    /// The user's message: its text chunks concatenated; path chunks are dropped.
    pub prompt: String,
    /// Working directory for the Codex session.
    pub cwd: PathBuf,
}
431
/// Params of a `codex/event` MCP notification. Only `msg` is read; any other
/// fields are ignored by serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CodexEvent {
    /// The event payload, discriminated by its `type` field.
    pub msg: AcpNotification,
}
436
/// Codex events this integration understands, discriminated by the `type`
/// field in snake_case (e.g. `"exec_command_begin"`).
///
/// Any other `type` deserializes as [`AcpNotification::Other`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AcpNotification {
    /// Assistant reply text.
    AgentMessage(AgentMessageEvent),
    /// Assistant reasoning text.
    AgentReasoning(AgentReasoningEvent),
    /// Codex started an MCP tool call.
    McpToolCallBegin(McpToolCallBeginEvent),
    /// An MCP tool call finished.
    McpToolCallEnd(McpToolCallEndEvent),
    /// Codex started running a command.
    ExecCommandBegin(ExecCommandBeginEvent),
    /// A command finished.
    ExecCommandEnd(ExecCommandEndEvent),
    /// Codex asks for approval before running a command.
    ExecApprovalRequest(ExecApprovalRequestEvent),
    /// Any event `type` not listed above.
    #[serde(other)]
    Other,
}
450
/// Assistant reply text; forwarded to the thread as a text chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMessageEvent {
    /// The reply text.
    pub message: String,
}
455
/// Assistant reasoning text; forwarded to the thread as a "thought" chunk.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentReasoningEvent {
    /// The reasoning text.
    pub text: String,
}
460
/// Codex began calling `tool` on MCP server `server`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallBeginEvent {
    /// Codex's id for this call; repeated by the matching [`McpToolCallEndEvent`].
    pub call_id: String,
    /// Name of the MCP server being called.
    pub server: String,
    /// Name of the tool being called.
    pub tool: String,
    /// Tool arguments, if any (arbitrary JSON).
    pub arguments: Option<serde_json::Value>,
}
468
/// An MCP tool call identified by `call_id` finished.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallEndEvent {
    /// Same id as in the corresponding [`McpToolCallBeginEvent`].
    pub call_id: String,
    /// On success, the raw tool response JSON (parsed later as a
    /// `CallToolResponse`). On failure, an error message.
    ///
    /// Uses serde's default `Result` encoding: `{"Ok": …}` / `{"Err": …}`.
    pub result: Result<serde_json::Value, String>,
}
474
/// Codex started running a command.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandBeginEvent {
    /// Codex's id for this command; repeated by the matching [`ExecCommandEndEvent`].
    pub call_id: String,
    /// Command argv. `["bash", "-lc", script]` is displayed as just `script`.
    pub command: Vec<String>,
    /// Working directory of the command (not currently displayed).
    pub cwd: PathBuf,
}
481
/// A command started by an [`ExecCommandBeginEvent`] finished.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandEndEvent {
    /// Same id as in the corresponding [`ExecCommandBeginEvent`].
    pub call_id: String,
    /// Captured standard output.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
    /// Process exit code; any non-zero value marks the tool call as failed.
    pub exit_code: i32,
}
489
/// Codex asks the user to approve running `command` in `cwd`.
///
/// NOTE(review): the user's decision does not appear to be sent back to Codex
/// yet (see the `todo! approval` marker in the notification handler).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecApprovalRequestEvent {
    /// Codex's id for the command awaiting approval.
    pub call_id: String,
    /// Command argv, displayed the same way as in [`ExecCommandBeginEvent`].
    pub command: Vec<String>,
    /// Working directory the command would run in (not currently displayed).
    pub cwd: PathBuf,
    /// Optional explanation, used as the confirmation's description.
    /// Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
498
499// Helper functions
/// Wraps `content` in a fenced Markdown code block tagged with `lang`.
/// Pass `""` for no language tag.
///
/// The fence is three backticks. If `content` itself contains a run of three
/// or more backticks (e.g. command output that prints Markdown), the fence is
/// made one backtick longer than the longest run. CommonMark only closes a
/// fence with a backtick fence at least as long as the opening one, so the
/// content cannot end the block early.
///
/// A newline is inserted before the closing fence only when `content` does
/// not already end with one.
fn md_codeblock(lang: &str, content: &str) -> String {
    // Splitting on every non-backtick char leaves only backtick runs (and empty
    // pieces). `split` always yields at least one piece, so `max` is `Some`.
    let longest_backtick_run = content
        .split(|c: char| c != '`')
        .map(str::len)
        .max()
        .unwrap_or(0);
    let fence = "`".repeat(longest_backtick_run.max(2) + 1);
    let newline = if content.ends_with('\n') { "" } else { "\n" };
    format!("{fence}{lang}\n{content}{newline}{fence}")
}
507
508fn strip_bash_lc_and_escape(command: &[String]) -> String {
509    match command {
510        // exactly three items
511        [first, second, third]
512            // first two must be "bash", "-lc"
513            if first == "bash" && second == "-lc" =>
514        {
515            third.clone()
516        }
517        _ => escape_command(command),
518    }
519}
520
521fn escape_command(command: &[String]) -> String {
522    shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
523}
524
525fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::ToolCallContent> {
526    let mut content = String::new();
527
528    for chunk in chunks {
529        match chunk {
530            ToolResponseContent::Text { text } => content.push_str(&text),
531            ToolResponseContent::Image { .. } => {
532                // todo!
533            }
534            ToolResponseContent::Audio { .. } => {
535                // todo!
536            }
537            ToolResponseContent::Resource { .. } => {
538                // todo!
539            }
540        }
541    }
542
543    if !content.is_empty() {
544        Some(acp::ToolCallContent::Markdown { markdown: content })
545    } else {
546        None
547    }
548}