1use agent_client_protocol as acp;
2use anyhow::anyhow;
3use collections::HashMap;
4use context_server::listener::McpServerTool;
5use context_server::types::requests;
6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
7use futures::channel::{mpsc, oneshot};
8use project::Project;
9use settings::SettingsStore;
10use smol::stream::StreamExt as _;
11use std::cell::RefCell;
12use std::rc::Rc;
13use std::{path::Path, sync::Arc};
14use util::ResultExt;
15
16use anyhow::{Context, Result};
17use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
18
19use crate::mcp_server::ZedMcpServer;
20use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
21use acp_thread::{AcpThread, AgentConnection};
22
/// Agent server backed by the `codex` command-line tool.
///
/// A zero-sized marker type; all behavior lives in its [`AgentServer`]
/// implementation.
#[derive(Clone)]
pub struct Codex;
25
26impl AgentServer for Codex {
27 fn name(&self) -> &'static str {
28 "Codex"
29 }
30
31 fn empty_state_headline(&self) -> &'static str {
32 "Welcome to Codex"
33 }
34
35 fn empty_state_message(&self) -> &'static str {
36 "What can I help with?"
37 }
38
39 fn logo(&self) -> ui::IconName {
40 ui::IconName::AiOpenAi
41 }
42
43 fn connect(
44 &self,
45 _root_dir: &Path,
46 project: &Entity<Project>,
47 cx: &mut App,
48 ) -> Task<Result<Rc<dyn AgentConnection>>> {
49 let project = project.clone();
50 cx.spawn(async move |cx| {
51 let settings = cx.read_global(|settings: &SettingsStore, _| {
52 settings.get::<AllAgentServersSettings>(None).codex.clone()
53 })?;
54
55 let Some(command) =
56 AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
57 else {
58 anyhow::bail!("Failed to find codex binary");
59 };
60
61 let client: Arc<ContextServer> = ContextServer::stdio(
62 ContextServerId("codex-mcp-server".into()),
63 ContextServerCommand {
64 path: command.path,
65 args: command.args,
66 env: command.env,
67 },
68 )
69 .into();
70 ContextServer::start(client.clone(), cx).await?;
71
72 let (notification_tx, mut notification_rx) = mpsc::unbounded();
73 client
74 .client()
75 .context("Failed to subscribe")?
76 .on_notification(acp::AGENT_METHODS.session_update, {
77 move |notification, _cx| {
78 let notification_tx = notification_tx.clone();
79 log::trace!(
80 "ACP Notification: {}",
81 serde_json::to_string_pretty(¬ification).unwrap()
82 );
83
84 if let Some(notification) =
85 serde_json::from_value::<acp::SessionNotification>(notification)
86 .log_err()
87 {
88 notification_tx.unbounded_send(notification).ok();
89 }
90 }
91 });
92
93 let sessions = Rc::new(RefCell::new(HashMap::default()));
94
95 let notification_handler_task = cx.spawn({
96 let sessions = sessions.clone();
97 async move |cx| {
98 while let Some(notification) = notification_rx.next().await {
99 CodexConnection::handle_session_notification(
100 notification,
101 sessions.clone(),
102 cx,
103 )
104 }
105 }
106 });
107
108 let connection = CodexConnection {
109 client,
110 sessions,
111 _notification_handler_task: notification_handler_task,
112 };
113 Ok(Rc::new(connection) as _)
114 })
115 }
116}
117
118struct CodexConnection {
119 client: Arc<context_server::ContextServer>,
120 sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
121 _notification_handler_task: Task<()>,
122}
123
124struct CodexSession {
125 thread: WeakEntity<AcpThread>,
126 cancel_tx: Option<oneshot::Sender<()>>,
127 _mcp_server: ZedMcpServer,
128}
129
130impl AgentConnection for CodexConnection {
131 fn name(&self) -> &'static str {
132 "Codex"
133 }
134
135 fn new_thread(
136 self: Rc<Self>,
137 project: Entity<Project>,
138 cwd: &Path,
139 cx: &mut AsyncApp,
140 ) -> Task<Result<Entity<AcpThread>>> {
141 let client = self.client.client();
142 let sessions = self.sessions.clone();
143 let cwd = cwd.to_path_buf();
144 cx.spawn(async move |cx| {
145 let client = client.context("MCP server is not initialized yet")?;
146 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
147
148 let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
149
150 let response = client
151 .request::<requests::CallTool>(context_server::types::CallToolParams {
152 name: acp::AGENT_METHODS.new_session.into(),
153 arguments: Some(serde_json::to_value(acp::NewSessionArguments {
154 mcp_servers: vec![mcp_server.server_config()?],
155 client_tools: acp::ClientTools {
156 request_permission: Some(acp::McpToolId {
157 mcp_server: mcp_server::SERVER_NAME.into(),
158 tool_name: mcp_server::RequestPermissionTool::NAME.into(),
159 }),
160 read_text_file: Some(acp::McpToolId {
161 mcp_server: mcp_server::SERVER_NAME.into(),
162 tool_name: mcp_server::ReadTextFileTool::NAME.into(),
163 }),
164 write_text_file: Some(acp::McpToolId {
165 mcp_server: mcp_server::SERVER_NAME.into(),
166 tool_name: mcp_server::WriteTextFileTool::NAME.into(),
167 }),
168 },
169 cwd,
170 })?),
171 meta: None,
172 })
173 .await?;
174
175 if response.is_error.unwrap_or_default() {
176 return Err(anyhow!(response.text_contents()));
177 }
178
179 let result = serde_json::from_value::<acp::NewSessionOutput>(
180 response.structured_content.context("Empty response")?,
181 )?;
182
183 let thread =
184 cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
185
186 thread_tx.send(thread.downgrade())?;
187
188 let session = CodexSession {
189 thread: thread.downgrade(),
190 cancel_tx: None,
191 _mcp_server: mcp_server,
192 };
193 sessions.borrow_mut().insert(result.session_id, session);
194
195 Ok(thread)
196 })
197 }
198
199 fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
200 Task::ready(Err(anyhow!("Authentication not supported")))
201 }
202
203 fn prompt(
204 &self,
205 params: agent_client_protocol::PromptArguments,
206 cx: &mut App,
207 ) -> Task<Result<()>> {
208 let client = self.client.client();
209 let sessions = self.sessions.clone();
210
211 cx.foreground_executor().spawn(async move {
212 let client = client.context("MCP server is not initialized yet")?;
213
214 let (new_cancel_tx, cancel_rx) = oneshot::channel();
215 {
216 let mut sessions = sessions.borrow_mut();
217 let session = sessions
218 .get_mut(¶ms.session_id)
219 .context("Session not found")?;
220 session.cancel_tx.replace(new_cancel_tx);
221 }
222
223 let result = client
224 .request_with::<requests::CallTool>(
225 context_server::types::CallToolParams {
226 name: acp::AGENT_METHODS.prompt.into(),
227 arguments: Some(serde_json::to_value(params)?),
228 meta: None,
229 },
230 Some(cancel_rx),
231 None,
232 )
233 .await;
234
235 if let Err(err) = &result
236 && err.is::<context_server::client::RequestCanceled>()
237 {
238 return Ok(());
239 }
240
241 let response = result?;
242
243 if response.is_error.unwrap_or_default() {
244 return Err(anyhow!(response.text_contents()));
245 }
246
247 Ok(())
248 })
249 }
250
251 fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
252 let mut sessions = self.sessions.borrow_mut();
253
254 if let Some(cancel_tx) = sessions
255 .get_mut(session_id)
256 .and_then(|session| session.cancel_tx.take())
257 {
258 cancel_tx.send(()).ok();
259 }
260 }
261}
262
263impl CodexConnection {
264 pub fn handle_session_notification(
265 notification: acp::SessionNotification,
266 threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
267 cx: &mut AsyncApp,
268 ) {
269 let threads = threads.borrow();
270 let Some(thread) = threads
271 .get(¬ification.session_id)
272 .and_then(|session| session.thread.upgrade())
273 else {
274 log::error!(
275 "Thread not found for session ID: {}",
276 notification.session_id
277 );
278 return;
279 };
280
281 thread
282 .update(cx, |thread, cx| {
283 thread.handle_session_update(notification.update, cx)
284 })
285 .log_err();
286 }
287}
288
289impl Drop for CodexConnection {
290 fn drop(&mut self) {
291 self.client.stop().log_err();
292 }
293}
294
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::AgentServerCommand;
    use std::path::Path;

    crate::common_e2e_tests!(Codex, allow_option_id = "approve");

    /// Command for a locally built debug `codex` binary, run with the `mcp`
    /// argument.
    ///
    /// The path is relative to this crate's manifest directory and points
    /// into a `codex` checkout three directories up.
    pub fn local_command() -> AgentServerCommand {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));

        AgentServerCommand {
            path: manifest_dir.join("../../../codex/codex-rs/target/debug/codex"),
            args: vec!["mcp".into()],
            env: None,
        }
    }
}