1use agent_client_protocol as acp;
2use anyhow::anyhow;
3use collections::HashMap;
4use context_server::listener::McpServerTool;
5use context_server::types::requests;
6use context_server::{ContextServer, ContextServerCommand, ContextServerId};
7use futures::channel::{mpsc, oneshot};
8use project::Project;
9use settings::SettingsStore;
10use smol::stream::StreamExt as _;
11use std::cell::RefCell;
12use std::rc::Rc;
13use std::{path::Path, sync::Arc};
14use util::ResultExt;
15
16use anyhow::{Context, Result};
17use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
18
19use crate::mcp_server::ZedMcpServer;
20use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
21use acp_thread::{AcpThread, AgentConnection};
22
/// Agent server for OpenAI Codex.
///
/// Connecting launches the `codex` binary with the `mcp` argument and talks to it as a
/// stdio MCP server (see `connect` in the `AgentServer` impl).
#[derive(Clone)]
pub struct Codex;
25
26impl AgentServer for Codex {
27 fn name(&self) -> &'static str {
28 "Codex"
29 }
30
31 fn empty_state_headline(&self) -> &'static str {
32 "Welcome to Codex"
33 }
34
35 fn empty_state_message(&self) -> &'static str {
36 "What can I help with?"
37 }
38
39 fn logo(&self) -> ui::IconName {
40 ui::IconName::AiOpenAi
41 }
42
43 fn connect(
44 &self,
45 _root_dir: &Path,
46 project: &Entity<Project>,
47 cx: &mut App,
48 ) -> Task<Result<Rc<dyn AgentConnection>>> {
49 let project = project.clone();
50 let working_directory = project.read(cx).active_project_directory(cx);
51 cx.spawn(async move |cx| {
52 let settings = cx.read_global(|settings: &SettingsStore, _| {
53 settings.get::<AllAgentServersSettings>(None).codex.clone()
54 })?;
55
56 let Some(command) =
57 AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
58 else {
59 anyhow::bail!("Failed to find codex binary");
60 };
61
62 let client: Arc<ContextServer> = ContextServer::stdio(
63 ContextServerId("codex-mcp-server".into()),
64 ContextServerCommand {
65 path: command.path,
66 args: command.args,
67 env: command.env,
68 },
69 working_directory,
70 )
71 .into();
72 ContextServer::start(client.clone(), cx).await?;
73
74 let (notification_tx, mut notification_rx) = mpsc::unbounded();
75 client
76 .client()
77 .context("Failed to subscribe")?
78 .on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
79 move |notification, _cx| {
80 let notification_tx = notification_tx.clone();
81 log::trace!(
82 "ACP Notification: {}",
83 serde_json::to_string_pretty(¬ification).unwrap()
84 );
85
86 if let Some(notification) =
87 serde_json::from_value::<acp::SessionNotification>(notification)
88 .log_err()
89 {
90 notification_tx.unbounded_send(notification).ok();
91 }
92 }
93 });
94
95 let sessions = Rc::new(RefCell::new(HashMap::default()));
96
97 let notification_handler_task = cx.spawn({
98 let sessions = sessions.clone();
99 async move |cx| {
100 while let Some(notification) = notification_rx.next().await {
101 CodexConnection::handle_session_notification(
102 notification,
103 sessions.clone(),
104 cx,
105 )
106 }
107 }
108 });
109
110 let connection = CodexConnection {
111 client,
112 sessions,
113 _notification_handler_task: notification_handler_task,
114 };
115 Ok(Rc::new(connection) as _)
116 })
117 }
118}
119
/// An `AgentConnection` backed by a running `codex mcp` server.
///
/// The `Drop` impl for this type stops the server process.
struct CodexConnection {
    /// Handle to the Codex MCP server process and its client.
    client: Arc<context_server::ContextServer>,
    /// Per-session state keyed by ACP session id. It is shared (via `Rc`) with the
    /// notification-handling task.
    sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
    /// Task that routes incoming session notifications to their threads. It is stored
    /// rather than detached, presumably so it lives exactly as long as the
    /// connection.
    _notification_handler_task: Task<()>,
}
125
/// State kept for a single Codex session.
struct CodexSession {
    /// The thread displaying this session. It is weak, so notifications for a
    /// released thread are logged and dropped.
    thread: WeakEntity<AcpThread>,
    /// Cancels the in-flight `prompt` request. It is set by `prompt` and taken by
    /// `cancel`.
    cancel_tx: Option<oneshot::Sender<()>>,
    /// Zed-side MCP server exposing permission and text-file tools to Codex. It is
    /// held so it stays alive for the session.
    _mcp_server: ZedMcpServer,
}
131
132impl AgentConnection for CodexConnection {
133 fn name(&self) -> &'static str {
134 "Codex"
135 }
136
137 fn new_thread(
138 self: Rc<Self>,
139 project: Entity<Project>,
140 cwd: &Path,
141 cx: &mut AsyncApp,
142 ) -> Task<Result<Entity<AcpThread>>> {
143 let client = self.client.client();
144 let sessions = self.sessions.clone();
145 let cwd = cwd.to_path_buf();
146 cx.spawn(async move |cx| {
147 let client = client.context("MCP server is not initialized yet")?;
148 let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
149
150 let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
151
152 let response = client
153 .request::<requests::CallTool>(context_server::types::CallToolParams {
154 name: acp::NEW_SESSION_TOOL_NAME.into(),
155 arguments: Some(serde_json::to_value(acp::NewSessionArguments {
156 mcp_servers: [(
157 mcp_server::SERVER_NAME.to_string(),
158 mcp_server.server_config()?,
159 )]
160 .into(),
161 client_tools: acp::ClientTools {
162 request_permission: Some(acp::McpToolId {
163 mcp_server: mcp_server::SERVER_NAME.into(),
164 tool_name: mcp_server::RequestPermissionTool::NAME.into(),
165 }),
166 read_text_file: Some(acp::McpToolId {
167 mcp_server: mcp_server::SERVER_NAME.into(),
168 tool_name: mcp_server::ReadTextFileTool::NAME.into(),
169 }),
170 write_text_file: Some(acp::McpToolId {
171 mcp_server: mcp_server::SERVER_NAME.into(),
172 tool_name: mcp_server::WriteTextFileTool::NAME.into(),
173 }),
174 },
175 cwd,
176 })?),
177 meta: None,
178 })
179 .await?;
180
181 if response.is_error.unwrap_or_default() {
182 return Err(anyhow!(response.text_contents()));
183 }
184
185 let result = serde_json::from_value::<acp::NewSessionOutput>(
186 response.structured_content.context("Empty response")?,
187 )?;
188
189 let thread =
190 cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
191
192 thread_tx.send(thread.downgrade())?;
193
194 let session = CodexSession {
195 thread: thread.downgrade(),
196 cancel_tx: None,
197 _mcp_server: mcp_server,
198 };
199 sessions.borrow_mut().insert(result.session_id, session);
200
201 Ok(thread)
202 })
203 }
204
205 fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
206 Task::ready(Err(anyhow!("Authentication not supported")))
207 }
208
209 fn prompt(
210 &self,
211 params: agent_client_protocol::PromptArguments,
212 cx: &mut App,
213 ) -> Task<Result<()>> {
214 let client = self.client.client();
215 let sessions = self.sessions.clone();
216
217 cx.foreground_executor().spawn(async move {
218 let client = client.context("MCP server is not initialized yet")?;
219
220 let (new_cancel_tx, cancel_rx) = oneshot::channel();
221 {
222 let mut sessions = sessions.borrow_mut();
223 let session = sessions
224 .get_mut(¶ms.session_id)
225 .context("Session not found")?;
226 session.cancel_tx.replace(new_cancel_tx);
227 }
228
229 let result = client
230 .request_with::<requests::CallTool>(
231 context_server::types::CallToolParams {
232 name: acp::PROMPT_TOOL_NAME.into(),
233 arguments: Some(serde_json::to_value(params)?),
234 meta: None,
235 },
236 Some(cancel_rx),
237 None,
238 )
239 .await;
240
241 if let Err(err) = &result
242 && err.is::<context_server::client::RequestCanceled>()
243 {
244 return Ok(());
245 }
246
247 let response = result?;
248
249 if response.is_error.unwrap_or_default() {
250 return Err(anyhow!(response.text_contents()));
251 }
252
253 Ok(())
254 })
255 }
256
257 fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
258 let mut sessions = self.sessions.borrow_mut();
259
260 if let Some(cancel_tx) = sessions
261 .get_mut(session_id)
262 .and_then(|session| session.cancel_tx.take())
263 {
264 cancel_tx.send(()).ok();
265 }
266 }
267}
268
269impl CodexConnection {
270 pub fn handle_session_notification(
271 notification: acp::SessionNotification,
272 threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
273 cx: &mut AsyncApp,
274 ) {
275 let threads = threads.borrow();
276 let Some(thread) = threads
277 .get(¬ification.session_id)
278 .and_then(|session| session.thread.upgrade())
279 else {
280 log::error!(
281 "Thread not found for session ID: {}",
282 notification.session_id
283 );
284 return;
285 };
286
287 thread
288 .update(cx, |thread, cx| {
289 thread.handle_session_update(notification.update, cx)
290 })
291 .log_err();
292 }
293}
294
295impl Drop for CodexConnection {
296 fn drop(&mut self) {
297 self.client.stop().log_err();
298 }
299}
300
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::AgentServerCommand;
    use std::path::Path;

    crate::common_e2e_tests!(Codex, allow_option_id = "approve");

    /// Command for a debug build of a local Codex checkout. The checkout is expected
    /// at `codex/codex-rs`, three levels above this crate's manifest directory. It is
    /// run with no extra arguments and no environment overrides.
    pub fn local_command() -> AgentServerCommand {
        let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));

        AgentServerCommand {
            path: manifest_dir.join("../../../codex/codex-rs/target/debug/codex"),
            args: Vec::new(),
            env: None,
        }
    }
}