1use agent_client_protocol as acp;
2use anyhow::anyhow;
3use collections::HashMap;
4use context_server::listener::McpServerTool;
5use context_server::types::requests;
6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
7use futures::channel::{mpsc, oneshot};
8use project::Project;
9use settings::SettingsStore;
10use smol::stream::StreamExt as _;
11use std::cell::RefCell;
12use std::rc::Rc;
13use std::{path::Path, sync::Arc};
14use util::ResultExt;
15
16use anyhow::{Context, Result};
17use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
18
19use crate::mcp_server::ZedMcpServer;
20use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
21use acp_thread::{AcpThread, AgentConnection};
22
/// [`AgentServer`] backed by the Codex CLI.
///
/// Connecting resolves the `codex` binary and runs it with the `mcp` argument as a
/// stdio MCP server (see `connect`).
#[derive(Clone)]
pub struct Codex;
25
26impl AgentServer for Codex {
27 fn name(&self) -> &'static str {
28 "Codex"
29 }
30
31 fn empty_state_headline(&self) -> &'static str {
32 "Welcome to Codex"
33 }
34
35 fn empty_state_message(&self) -> &'static str {
36 "What can I help with?"
37 }
38
39 fn logo(&self) -> ui::IconName {
40 ui::IconName::AiOpenAi
41 }
42
43 fn connect(
44 &self,
45 _root_dir: &Path,
46 project: &Entity<Project>,
47 cx: &mut App,
48 ) -> Task<Result<Rc<dyn AgentConnection>>> {
49 let project = project.clone();
50 cx.spawn(async move |cx| {
51 let settings = cx.read_global(|settings: &SettingsStore, _| {
52 settings.get::<AllAgentServersSettings>(None).codex.clone()
53 })?;
54
55 let Some(command) =
56 AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
57 else {
58 anyhow::bail!("Failed to find codex binary");
59 };
60
61 let client: Arc<ContextServer> = ContextServer::stdio(
62 ContextServerId("codex-mcp-server".into()),
63 ContextServerCommand {
64 path: command.path,
65 args: command.args,
66 env: command.env,
67 },
68 )
69 .into();
70 ContextServer::start(client.clone(), cx).await?;
71
72 let (notification_tx, mut notification_rx) = mpsc::unbounded();
73 client
74 .client()
75 .context("Failed to subscribe")?
76 .on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
77 move |notification, _cx| {
78 let notification_tx = notification_tx.clone();
79 log::trace!(
80 "ACP Notification: {}",
81 serde_json::to_string_pretty(¬ification).unwrap()
82 );
83
84 if let Some(notification) =
85 serde_json::from_value::<acp::SessionNotification>(notification)
86 .log_err()
87 {
88 notification_tx.unbounded_send(notification).ok();
89 }
90 }
91 });
92
93 let sessions = Rc::new(RefCell::new(HashMap::default()));
94
95 let notification_handler_task = cx.spawn({
96 let sessions = sessions.clone();
97 async move |cx| {
98 while let Some(notification) = notification_rx.next().await {
99 CodexConnection::handle_session_notification(
100 notification,
101 sessions.clone(),
102 cx,
103 )
104 }
105 }
106 });
107
108 let connection = CodexConnection {
109 client,
110 sessions,
111 _notification_handler_task: notification_handler_task,
112 };
113 Ok(Rc::new(connection) as _)
114 })
115 }
116}
117
/// [`AgentConnection`] to a running Codex MCP server process.
///
/// The server is stopped by this type's `Drop` impl, which runs before any field is
/// dropped.
struct CodexConnection {
    /// Handle to the spawned Codex MCP server; used to issue tool-call requests.
    client: Arc<context_server::ContextServer>,
    /// Per-session state keyed by the session ID returned from the new-session tool
    /// call. Shared with the notification handler task.
    sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
    /// Task routing session-update notifications to threads. Held so it lives as long
    /// as the connection (presumably dropping a gpui `Task` cancels it — confirm).
    _notification_handler_task: Task<()>,
}
123
/// State tracked for one Codex session.
struct CodexSession {
    /// Thread that receives this session's updates. Weak, so the session map does not
    /// keep the thread alive.
    thread: WeakEntity<AcpThread>,
    /// Sender that cancels the most recently started prompt request, if any.
    cancel_tx: Option<oneshot::Sender<()>>,
    /// Zed-side MCP server that provides the permission and file read/write client
    /// tools to Codex. Only held here to keep it alive for the session.
    _mcp_server: ZedMcpServer,
}
129
130impl AgentConnection for CodexConnection {
131 fn name(&self) -> &'static str {
132 "Codex"
133 }
134
135 fn new_thread(
136 self: Rc<Self>,
137 project: Entity<Project>,
138 cwd: &Path,
139 cx: &mut AsyncApp,
140 ) -> Task<Result<Entity<AcpThread>>> {
141 let client = self.client.client();
142 let sessions = self.sessions.clone();
143 let cwd = cwd.to_path_buf();
144 cx.spawn(async move |cx| {
145 let client = client.context("MCP server is not initialized yet")?;
146 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
147
148 let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
149
150 let response = client
151 .request::<requests::CallTool>(context_server::types::CallToolParams {
152 name: acp::NEW_SESSION_TOOL_NAME.into(),
153 arguments: Some(serde_json::to_value(acp::NewSessionArguments {
154 mcp_servers: [(
155 mcp_server::SERVER_NAME.to_string(),
156 mcp_server.server_config()?,
157 )]
158 .into(),
159 client_tools: acp::ClientTools {
160 request_permission: Some(acp::McpToolId {
161 mcp_server: mcp_server::SERVER_NAME.into(),
162 tool_name: mcp_server::RequestPermissionTool::NAME.into(),
163 }),
164 read_text_file: Some(acp::McpToolId {
165 mcp_server: mcp_server::SERVER_NAME.into(),
166 tool_name: mcp_server::ReadTextFileTool::NAME.into(),
167 }),
168 write_text_file: Some(acp::McpToolId {
169 mcp_server: mcp_server::SERVER_NAME.into(),
170 tool_name: mcp_server::WriteTextFileTool::NAME.into(),
171 }),
172 },
173 cwd,
174 })?),
175 meta: None,
176 })
177 .await?;
178
179 if response.is_error.unwrap_or_default() {
180 return Err(anyhow!(response.text_contents()));
181 }
182
183 let result = serde_json::from_value::<acp::NewSessionOutput>(
184 response.structured_content.context("Empty response")?,
185 )?;
186
187 let thread =
188 cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
189
190 thread_tx.send(thread.downgrade())?;
191
192 let session = CodexSession {
193 thread: thread.downgrade(),
194 cancel_tx: None,
195 _mcp_server: mcp_server,
196 };
197 sessions.borrow_mut().insert(result.session_id, session);
198
199 Ok(thread)
200 })
201 }
202
203 fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
204 Task::ready(Err(anyhow!("Authentication not supported")))
205 }
206
207 fn prompt(
208 &self,
209 params: agent_client_protocol::PromptArguments,
210 cx: &mut App,
211 ) -> Task<Result<()>> {
212 let client = self.client.client();
213 let sessions = self.sessions.clone();
214
215 cx.foreground_executor().spawn(async move {
216 let client = client.context("MCP server is not initialized yet")?;
217
218 let (new_cancel_tx, cancel_rx) = oneshot::channel();
219 {
220 let mut sessions = sessions.borrow_mut();
221 let session = sessions
222 .get_mut(¶ms.session_id)
223 .context("Session not found")?;
224 session.cancel_tx.replace(new_cancel_tx);
225 }
226
227 let result = client
228 .request_with::<requests::CallTool>(
229 context_server::types::CallToolParams {
230 name: acp::PROMPT_TOOL_NAME.into(),
231 arguments: Some(serde_json::to_value(params)?),
232 meta: None,
233 },
234 Some(cancel_rx),
235 None,
236 )
237 .await;
238
239 if let Err(err) = &result
240 && err.is::<context_server::client::RequestCanceled>()
241 {
242 return Ok(());
243 }
244
245 let response = result?;
246
247 if response.is_error.unwrap_or_default() {
248 return Err(anyhow!(response.text_contents()));
249 }
250
251 Ok(())
252 })
253 }
254
255 fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
256 let mut sessions = self.sessions.borrow_mut();
257
258 if let Some(cancel_tx) = sessions
259 .get_mut(session_id)
260 .and_then(|session| session.cancel_tx.take())
261 {
262 cancel_tx.send(()).ok();
263 }
264 }
265}
266
267impl CodexConnection {
268 pub fn handle_session_notification(
269 notification: acp::SessionNotification,
270 threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
271 cx: &mut AsyncApp,
272 ) {
273 let threads = threads.borrow();
274 let Some(thread) = threads
275 .get(¬ification.session_id)
276 .and_then(|session| session.thread.upgrade())
277 else {
278 log::error!(
279 "Thread not found for session ID: {}",
280 notification.session_id
281 );
282 return;
283 };
284
285 thread
286 .update(cx, |thread, cx| {
287 thread.handle_session_update(notification.update, cx)
288 })
289 .log_err();
290 }
291}
292
293impl Drop for CodexConnection {
294 fn drop(&mut self) {
295 self.client.stop().log_err();
296 }
297}
298
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::AgentServerCommand;
    use std::path::Path;

    crate::common_e2e_tests!(Codex);

    /// Command for a locally built Codex debug binary, expected at
    /// `codex/codex-rs/target/debug/codex` three directories above this crate's
    /// manifest directory, launched with the `mcp` argument.
    pub fn local_command() -> AgentServerCommand {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
        let binary = manifest_dir.join("../../../codex/codex-rs/target/debug/codex");

        AgentServerCommand {
            path: binary,
            args: vec!["mcp".into()],
            env: None,
        }
    }
}