1use collections::HashMap;
2use context_server::types::requests::CallTool;
3use context_server::types::{CallToolParams, ToolResponseContent};
4use context_server::{ContextServer, ContextServerCommand, ContextServerId};
5use futures::channel::{mpsc, oneshot};
6use project::Project;
7use settings::SettingsStore;
8use smol::stream::StreamExt;
9use std::cell::RefCell;
10use std::path::{Path, PathBuf};
11use std::rc::Rc;
12use std::sync::Arc;
13
14use agentic_coding_protocol::{
15 self as acp, AnyAgentRequest, AnyAgentResult, Client as _, ProtocolVersion,
16};
17use anyhow::{Context, Result, anyhow};
18use futures::future::LocalBoxFuture;
19use futures::{AsyncWriteExt, FutureExt, SinkExt as _};
20use gpui::{App, AppContext, Entity, Task};
21use serde::{Deserialize, Serialize};
22use util::ResultExt;
23
24use crate::mcp_server::{McpConfig, ZedMcpServer};
25use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings};
26use acp_thread::{AcpClientDelegate, AcpThread, AgentConnection};
27
/// [`AgentServer`] backed by the Codex CLI, launched as `codex mcp` and driven as a stdio
/// MCP context server.
#[derive(Clone)]
pub struct Codex;
30
31impl AgentServer for Codex {
32 fn name(&self) -> &'static str {
33 "Codex"
34 }
35
36 fn empty_state_headline(&self) -> &'static str {
37 self.name()
38 }
39
40 fn empty_state_message(&self) -> &'static str {
41 ""
42 }
43
44 fn logo(&self) -> ui::IconName {
45 ui::IconName::AiOpenAi
46 }
47
48 fn supports_always_allow(&self) -> bool {
49 false
50 }
51
52 fn new_thread(
53 &self,
54 root_dir: &Path,
55 project: &Entity<Project>,
56 cx: &mut App,
57 ) -> Task<Result<Entity<AcpThread>>> {
58 let project = project.clone();
59 let root_dir = root_dir.to_path_buf();
60 let title = self.name().into();
61 cx.spawn(async move |cx| {
62 let (mut delegate_tx, delegate_rx) = watch::channel(None);
63 let tool_id_map = Rc::new(RefCell::new(HashMap::default()));
64
65 let zed_mcp_server = ZedMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?;
66
67 let mut mcp_servers = HashMap::default();
68 mcp_servers.insert(
69 crate::mcp_server::SERVER_NAME.to_string(),
70 zed_mcp_server.server_config()?,
71 );
72 let mcp_config = McpConfig { mcp_servers };
73
74 // todo! pass zed mcp server to codex tool
75 let mcp_config_file = tempfile::NamedTempFile::new()?;
76 let (mcp_config_file, _mcp_config_path) = mcp_config_file.into_parts();
77
78 let mut mcp_config_file = smol::fs::File::from(mcp_config_file);
79 mcp_config_file
80 .write_all(serde_json::to_string(&mcp_config)?.as_bytes())
81 .await?;
82 mcp_config_file.flush().await?;
83
84 let settings = cx.read_global(|settings: &SettingsStore, _| {
85 settings.get::<AllAgentServersSettings>(None).codex.clone()
86 })?;
87
88 let Some(command) =
89 AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
90 else {
91 anyhow::bail!("Failed to find codex binary");
92 };
93
94 let codex_mcp_client: Arc<ContextServer> = ContextServer::stdio(
95 ContextServerId("codex-mcp-server".into()),
96 ContextServerCommand {
97 path: command.path,
98 args: command.args,
99 env: command.env,
100 },
101 )
102 .into();
103
104 ContextServer::start(codex_mcp_client.clone(), cx).await?;
105 // todo! stop
106
107 let (notification_tx, mut notification_rx) = mpsc::unbounded();
108
109 codex_mcp_client
110 .client()
111 .context("Failed to subscribe to server")?
112 .on_notification("codex/event", {
113 move |event, cx| {
114 let mut notification_tx = notification_tx.clone();
115 cx.background_spawn(async move {
116 log::trace!("Notification: {:?}", event);
117 if let Some(event) =
118 serde_json::from_value::<CodexEvent>(event).log_err()
119 {
120 notification_tx.send(event.msg).await.log_err();
121 }
122 })
123 .detach();
124 }
125 });
126
127 cx.new(|cx| {
128 let delegate = AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async());
129 delegate_tx.send(Some(delegate.clone())).log_err();
130
131 let handler_task = cx.spawn({
132 let delegate = delegate.clone();
133 let tool_id_map = tool_id_map.clone();
134 async move |_, _cx| {
135 while let Some(notification) = notification_rx.next().await {
136 CodexAgentConnection::handle_acp_notification(
137 &delegate,
138 notification,
139 &tool_id_map,
140 )
141 .await
142 .log_err();
143 }
144 }
145 });
146
147 let connection = CodexAgentConnection {
148 root_dir,
149 codex_mcp: codex_mcp_client,
150 cancel_request_tx: Default::default(),
151 tool_id_map: tool_id_map.clone(),
152 _handler_task: handler_task,
153 _zed_mcp: zed_mcp_server,
154 };
155
156 acp_thread::AcpThread::new(connection, title, None, project.clone(), cx)
157 })
158 })
159 }
160}
161
162impl AgentConnection for CodexAgentConnection {
163 /// Send a request to the agent and wait for a response.
164 fn request_any(
165 &self,
166 params: AnyAgentRequest,
167 ) -> LocalBoxFuture<'static, Result<acp::AnyAgentResult>> {
168 let client = self.codex_mcp.client();
169 let root_dir = self.root_dir.clone();
170 let cancel_request_tx = self.cancel_request_tx.clone();
171 async move {
172 let client = client.context("Codex MCP server is not initialized")?;
173
174 match params {
175 // todo: consider sending an empty request so we get the init response?
176 AnyAgentRequest::InitializeParams(_) => Ok(AnyAgentResult::InitializeResponse(
177 acp::InitializeResponse {
178 is_authenticated: true,
179 protocol_version: ProtocolVersion::latest(),
180 },
181 )),
182 AnyAgentRequest::AuthenticateParams(_) => {
183 Err(anyhow!("Authentication not supported"))
184 }
185 AnyAgentRequest::SendUserMessageParams(message) => {
186 let (new_cancel_tx, cancel_rx) = oneshot::channel();
187 cancel_request_tx.borrow_mut().replace(new_cancel_tx);
188
189 client
190 .cancellable_request::<CallTool>(
191 CallToolParams {
192 name: "codex".into(),
193 arguments: Some(serde_json::to_value(CodexToolCallParam {
194 prompt: message
195 .chunks
196 .into_iter()
197 .filter_map(|chunk| match chunk {
198 acp::UserMessageChunk::Text { text } => Some(text),
199 acp::UserMessageChunk::Path { .. } => {
200 // todo!
201 None
202 }
203 })
204 .collect(),
205 cwd: root_dir,
206 })?),
207 meta: None,
208 },
209 cancel_rx,
210 )
211 .await?;
212
213 Ok(AnyAgentResult::SendUserMessageResponse(
214 acp::SendUserMessageResponse,
215 ))
216 }
217 AnyAgentRequest::CancelSendMessageParams(_) => {
218 if let Ok(mut borrow) = cancel_request_tx.try_borrow_mut() {
219 if let Some(cancel_tx) = borrow.take() {
220 cancel_tx.send(()).ok();
221 }
222 }
223
224 Ok(AnyAgentResult::CancelSendMessageResponse(
225 acp::CancelSendMessageResponse,
226 ))
227 }
228 }
229 }
230 .boxed_local()
231 }
232}
233
/// [`AgentConnection`] that forwards ACP requests to a Codex MCP server.
///
/// NOTE(review): Rust drops fields in declaration order, so `codex_mcp` is dropped before the
/// handler task and the Zed MCP server. Nothing visible in this file stops the Codex server
/// when the connection is dropped.
struct CodexAgentConnection {
    /// Codex MCP server; `request_any` sends tool calls through its client.
    codex_mcp: Arc<context_server::ContextServer>,
    /// Directory passed to Codex as `cwd` with every prompt.
    root_dir: PathBuf,
    /// Cancellation sender stored by the most recent `SendUserMessageParams` request.
    cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
    /// Maps Codex `call_id`s to the ACP tool-call ids created for them.
    tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
    /// Task that drains Codex notifications into the thread; held for the connection's lifetime.
    _handler_task: Task<()>,
    /// Zed's own MCP server, held for the connection's lifetime.
    _zed_mcp: ZedMcpServer,
}
242
243impl CodexAgentConnection {
244 async fn handle_acp_notification(
245 delegate: &AcpClientDelegate,
246 event: AcpNotification,
247 tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
248 ) -> Result<()> {
249 match event {
250 AcpNotification::AgentMessage(message) => {
251 delegate
252 .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
253 chunk: acp::AssistantMessageChunk::Text {
254 text: message.message,
255 },
256 })
257 .await?;
258 }
259 AcpNotification::AgentReasoning(message) => {
260 delegate
261 .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
262 chunk: acp::AssistantMessageChunk::Thought {
263 thought: message.text,
264 },
265 })
266 .await?
267 }
268 AcpNotification::McpToolCallBegin(event) => {
269 let result = delegate
270 .push_tool_call(acp::PushToolCallParams {
271 label: format!("`{}: {}`", event.server, event.tool),
272 icon: acp::Icon::Hammer,
273 content: event.arguments.and_then(|args| {
274 Some(acp::ToolCallContent::Markdown {
275 markdown: md_codeblock(
276 "json",
277 &serde_json::to_string_pretty(&args).ok()?,
278 ),
279 })
280 }),
281 locations: vec![],
282 })
283 .await?;
284
285 tool_id_map.borrow_mut().insert(event.call_id, result.id);
286 }
287 AcpNotification::McpToolCallEnd(event) => {
288 let acp_call_id = tool_id_map
289 .borrow_mut()
290 .remove(&event.call_id)
291 .context("Missing tool call")?;
292
293 let (status, content) = match event.result {
294 Ok(value) => {
295 if let Ok(response) =
296 serde_json::from_value::<context_server::types::CallToolResponse>(value)
297 {
298 (
299 acp::ToolCallStatus::Finished,
300 mcp_tool_content_to_acp(response.content),
301 )
302 } else {
303 (
304 acp::ToolCallStatus::Error,
305 Some(acp::ToolCallContent::Markdown {
306 markdown: "Failed to parse tool response".to_string(),
307 }),
308 )
309 }
310 }
311 Err(error) => (
312 acp::ToolCallStatus::Error,
313 Some(acp::ToolCallContent::Markdown { markdown: error }),
314 ),
315 };
316
317 delegate
318 .update_tool_call(acp::UpdateToolCallParams {
319 tool_call_id: acp_call_id,
320 status,
321 content,
322 })
323 .await?;
324 }
325 AcpNotification::ExecCommandBegin(event) => {
326 let inner_command = strip_bash_lc_and_escape(&event.command);
327
328 let result = delegate
329 .push_tool_call(acp::PushToolCallParams {
330 label: format!("`{}`", inner_command),
331 icon: acp::Icon::Terminal,
332 content: None,
333 locations: vec![],
334 })
335 .await?;
336
337 tool_id_map.borrow_mut().insert(event.call_id, result.id);
338 }
339 AcpNotification::ExecCommandEnd(event) => {
340 let acp_call_id = tool_id_map
341 .borrow_mut()
342 .remove(&event.call_id)
343 .context("Missing tool call")?;
344
345 let mut content = String::new();
346 if !event.stdout.is_empty() {
347 use std::fmt::Write;
348 writeln!(
349 &mut content,
350 "### Output\n\n{}",
351 md_codeblock("", &event.stdout)
352 )
353 .unwrap();
354 }
355 if !event.stdout.is_empty() && !event.stderr.is_empty() {
356 use std::fmt::Write;
357 writeln!(&mut content).unwrap();
358 }
359 if !event.stderr.is_empty() {
360 use std::fmt::Write;
361 writeln!(
362 &mut content,
363 "### Error\n\n{}",
364 md_codeblock("", &event.stderr)
365 )
366 .unwrap();
367 }
368 let success = event.exit_code == 0;
369 if !success {
370 use std::fmt::Write;
371 writeln!(&mut content, "\nExit code: `{}`", event.exit_code).unwrap();
372 }
373
374 delegate
375 .update_tool_call(acp::UpdateToolCallParams {
376 tool_call_id: acp_call_id,
377 status: if success {
378 acp::ToolCallStatus::Finished
379 } else {
380 acp::ToolCallStatus::Error
381 },
382 content: Some(acp::ToolCallContent::Markdown { markdown: content }),
383 })
384 .await?;
385 }
386 AcpNotification::ExecApprovalRequest(event) => {
387 let inner_command = strip_bash_lc_and_escape(&event.command);
388 let root_command = inner_command
389 .split(" ")
390 .next()
391 .map(|s| s.to_string())
392 .unwrap_or_default();
393
394 let response = delegate
395 .request_tool_call_confirmation(acp::RequestToolCallConfirmationParams {
396 tool_call: acp::PushToolCallParams {
397 label: format!("`{}`", inner_command),
398 icon: acp::Icon::Terminal,
399 content: None,
400 locations: vec![],
401 },
402 confirmation: acp::ToolCallConfirmation::Execute {
403 command: inner_command,
404 root_command,
405 description: event.reason,
406 },
407 })
408 .await?;
409
410 tool_id_map.borrow_mut().insert(event.call_id, response.id);
411
412 // todo! approval
413 }
414 AcpNotification::Other => {}
415 }
416
417 Ok(())
418 }
419}
420
// todo! use types from h2a crate when we have one

/// Arguments for the `codex` MCP tool, sent with each user message.
///
/// Neither field name contains an underscore, so `kebab-case` renaming leaves the wire names
/// as `prompt` and `cwd`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub(crate) struct CodexToolCallParam {
    /// Concatenated text chunks of the user's message.
    pub prompt: String,
    /// Working directory for Codex; set to the thread's root directory.
    pub cwd: PathBuf,
}
429
/// Params of a `codex/event` MCP notification. Only `msg` is modelled; other fields are
/// ignored during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CodexEvent {
    /// The event itself, forwarded to the connection's notification handler.
    pub msg: AcpNotification,
}
434
/// An event emitted by Codex, internally tagged by its `type` field.
///
/// Variant names map to snake_case tags (e.g. `AgentMessage` <-> `"agent_message"`). Any
/// other tag deserializes as [`AcpNotification::Other`], which the connection ignores.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AcpNotification {
    /// Assistant text, streamed into the thread as a message chunk.
    AgentMessage(AgentMessageEvent),
    /// Assistant reasoning, streamed into the thread as a thought chunk.
    AgentReasoning(AgentReasoningEvent),
    /// Start of an MCP tool call; creates a tool-call entry.
    McpToolCallBegin(McpToolCallBeginEvent),
    /// End of an MCP tool call; updates the entry created by the matching begin event.
    McpToolCallEnd(McpToolCallEndEvent),
    /// Start of a command execution; creates a terminal tool-call entry.
    ExecCommandBegin(ExecCommandBeginEvent),
    /// End of a command execution; reports output and exit status.
    ExecCommandEnd(ExecCommandEndEvent),
    /// Request to approve a command; shown to the user as an execute confirmation.
    ExecApprovalRequest(ExecApprovalRequestEvent),
    /// Any event type not modelled above.
    #[serde(other)]
    Other,
}
448
/// Payload of an `agent_message` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMessageEvent {
    /// Assistant text appended to the thread.
    pub message: String,
}
453
/// Payload of an `agent_reasoning` event.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AgentReasoningEvent {
    /// Reasoning text, shown in the thread as a thought.
    pub text: String,
}
458
/// Payload of an `mcp_tool_call_begin` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallBeginEvent {
    /// Codex's id for this call; used to pair it with its `mcp_tool_call_end`.
    pub call_id: String,
    /// Server the tool belongs to (shown in the label as `server: tool`).
    pub server: String,
    /// Name of the invoked tool.
    pub tool: String,
    /// Tool arguments, if any; rendered as pretty-printed JSON.
    pub arguments: Option<serde_json::Value>,
}
466
/// Payload of an `mcp_tool_call_end` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCallEndEvent {
    /// Id of the matching `mcp_tool_call_begin`.
    pub call_id: String,
    /// `Ok` holds a value the handler parses as an MCP `CallToolResponse`; `Err` holds an
    /// error message. Serde's default `Result` encoding is externally tagged
    /// (`{"Ok": ...}` / `{"Err": ...}`). NOTE(review): confirm this matches Codex's wire format.
    pub result: Result<serde_json::Value, String>,
}
472
/// Payload of an `exec_command_begin` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandBeginEvent {
    /// Id used to pair this event with its `exec_command_end`.
    pub call_id: String,
    /// Command as an argument vector; a `["bash", "-lc", script]` form is displayed as `script`.
    pub command: Vec<String>,
    /// Presumably the command's working directory; not read by the handler.
    pub cwd: PathBuf,
}
479
/// Payload of an `exec_command_end` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecCommandEndEvent {
    /// Id of the matching `exec_command_begin`.
    pub call_id: String,
    /// Captured standard output; rendered under "### Output" when non-empty.
    pub stdout: String,
    /// Captured standard error; rendered under "### Error" when non-empty.
    pub stderr: String,
    /// Exit code; `0` marks the tool call finished, anything else marks it as an error.
    pub exit_code: i32,
}
487
/// Payload of an `exec_approval_request` event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecApprovalRequestEvent {
    /// Id recorded in the tool-call map for the confirmation created from this request.
    pub call_id: String,
    /// Command awaiting approval, as an argument vector.
    pub command: Vec<String>,
    /// Presumably the directory the command would run in; not read by the handler.
    pub cwd: PathBuf,
    /// Optional explanation shown as the confirmation's description; omitted from serialized
    /// JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
496
497// Helper functions
/// Wraps `content` in a fenced Markdown code block tagged with `lang`.
///
/// The fence is one backtick longer than the longest run of backticks in `content`, with a
/// minimum of three. This keeps content that itself contains a fence (e.g. command output or
/// Markdown-bearing JSON) from closing the block early. A newline is inserted before the
/// closing fence only when `content` doesn't already end with one.
fn md_codeblock(lang: &str, content: &str) -> String {
    let mut longest_run = 0usize;
    let mut current_run = 0usize;
    for ch in content.chars() {
        if ch == '`' {
            current_run += 1;
            longest_run = longest_run.max(current_run);
        } else {
            current_run = 0;
        }
    }

    let fence = "`".repeat((longest_run + 1).max(3));
    let newline = if content.ends_with('\n') { "" } else { "\n" };
    format!("{fence}{lang}\n{content}{newline}{fence}")
}
505
506fn strip_bash_lc_and_escape(command: &[String]) -> String {
507 match command {
508 // exactly three items
509 [first, second, third]
510 // first two must be "bash", "-lc"
511 if first == "bash" && second == "-lc" =>
512 {
513 third.clone()
514 }
515 _ => escape_command(command),
516 }
517}
518
519fn escape_command(command: &[String]) -> String {
520 shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
521}
522
523fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::ToolCallContent> {
524 let mut content = String::new();
525
526 for chunk in chunks {
527 match chunk {
528 ToolResponseContent::Text { text } => content.push_str(&text),
529 ToolResponseContent::Image { .. } => {
530 // todo!
531 }
532 ToolResponseContent::Audio { .. } => {
533 // todo!
534 }
535 ToolResponseContent::Resource { .. } => {
536 // todo!
537 }
538 }
539 }
540
541 if !content.is_empty() {
542 Some(acp::ToolCallContent::Markdown { markdown: content })
543 } else {
544 None
545 }
546}
547
#[cfg(test)]
pub mod tests {
    use super::*;

    /// Command that runs a locally built debug Codex binary as `codex mcp`.
    ///
    /// The binary is expected at `../../../codex/target/debug/codex`, relative to this crate's
    /// manifest directory.
    pub fn local_command() -> AgentServerCommand {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
        let path = manifest_dir.join("../../../codex/target/debug/codex");

        AgentServerCommand {
            path,
            args: vec!["mcp".into()],
            env: None,
        }
    }
}