Display tool calls

Agus Zubiaga created

Change summary

Cargo.lock                        |   1 
crates/agent_servers/Cargo.toml   |   1 
crates/agent_servers/src/codex.rs | 256 ++++++++++++++++++++++++++++++++
3 files changed, 254 insertions(+), 4 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -157,6 +157,7 @@ dependencies = [
  "serde",
  "serde_json",
  "settings",
+ "shlex",
  "smol",
  "strum 0.27.1",
  "tempfile",

crates/agent_servers/Cargo.toml 🔗

@@ -39,6 +39,7 @@ ui.workspace = true
 util.workspace = true
 watch.workspace = true
 which.workspace = true
+shlex.workspace = true
 workspace-hack.workspace = true
 
 [dev-dependencies]

crates/agent_servers/src/codex.rs 🔗

@@ -1,6 +1,6 @@
 use collections::HashMap;
-use context_server::types::CallToolParams;
 use context_server::types::requests::CallTool;
+use context_server::types::{CallToolParams, ToolResponseContent};
 use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use futures::channel::{mpsc, oneshot};
 use project::Project;
@@ -132,11 +132,16 @@ impl AgentServer for Codex {
 
                 let handler_task = cx.spawn({
                     let delegate = delegate.clone();
+                    let tool_id_map = tool_id_map.clone();
                     async move |_, _cx| {
                         while let Some(notification) = notification_rx.next().await {
-                            CodexAgentConnection::handle_acp_notification(&delegate, notification)
-                                .await
-                                .log_err();
+                            CodexAgentConnection::handle_acp_notification(
+                                &delegate,
+                                notification,
+                                &tool_id_map,
+                            )
+                            .await
+                            .log_err();
                         }
                     }
                 });
@@ -145,6 +150,7 @@ impl AgentServer for Codex {
                     root_dir,
                     codex_mcp: codex_mcp_client,
                     cancel_request_tx: Default::default(),
+                    tool_id_map: tool_id_map.clone(),
                     _handler_task: handler_task,
                     _zed_mcp: zed_mcp_server,
                 };
@@ -231,6 +237,7 @@ struct CodexAgentConnection {
     codex_mcp: Arc<context_server::ContextServer>,
     root_dir: PathBuf,
     cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
+    tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
     _handler_task: Task<()>,
     _zed_mcp: ZedMcpServer,
 }
@@ -239,6 +246,7 @@ impl CodexAgentConnection {
     async fn handle_acp_notification(
         delegate: &AcpClientDelegate,
         event: AcpNotification,
+        tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
     ) -> Result<()> {
         match event {
             AcpNotification::AgentMessage(message) => {
@@ -259,6 +267,152 @@ impl CodexAgentConnection {
                     })
                     .await?
             }
+            AcpNotification::McpToolCallBegin(event) => {
+                let result = delegate
+                    .push_tool_call(acp::PushToolCallParams {
+                        label: format!("`{}: {}`", event.server, event.tool),
+                        icon: acp::Icon::Hammer,
+                        content: event.arguments.and_then(|args| {
+                            Some(acp::ToolCallContent::Markdown {
+                                markdown: md_codeblock(
+                                    "json",
+                                    &serde_json::to_string_pretty(&args).ok()?,
+                                ),
+                            })
+                        }),
+                        locations: vec![],
+                    })
+                    .await?;
+
+                tool_id_map.borrow_mut().insert(event.call_id, result.id);
+            }
+            AcpNotification::McpToolCallEnd(event) => {
+                let acp_call_id = tool_id_map
+                    .borrow_mut()
+                    .remove(&event.call_id)
+                    .context("Missing tool call")?;
+
+                let (status, content) = match event.result {
+                    Ok(value) => {
+                        if let Ok(response) =
+                            serde_json::from_value::<context_server::types::CallToolResponse>(value)
+                        {
+                            (
+                                acp::ToolCallStatus::Finished,
+                                mcp_tool_content_to_acp(response.content),
+                            )
+                        } else {
+                            (
+                                acp::ToolCallStatus::Error,
+                                Some(acp::ToolCallContent::Markdown {
+                                    markdown: "Failed to parse tool response".to_string(),
+                                }),
+                            )
+                        }
+                    }
+                    Err(error) => (
+                        acp::ToolCallStatus::Error,
+                        Some(acp::ToolCallContent::Markdown { markdown: error }),
+                    ),
+                };
+
+                delegate
+                    .update_tool_call(acp::UpdateToolCallParams {
+                        tool_call_id: acp_call_id,
+                        status,
+                        content,
+                    })
+                    .await?;
+            }
+            AcpNotification::ExecCommandBegin(event) => {
+                let inner_command = strip_bash_lc_and_escape(&event.command);
+
+                let result = delegate
+                    .push_tool_call(acp::PushToolCallParams {
+                        label: format!("`{}`", inner_command),
+                        icon: acp::Icon::Terminal,
+                        content: None,
+                        locations: vec![],
+                    })
+                    .await?;
+
+                tool_id_map.borrow_mut().insert(event.call_id, result.id);
+            }
+            AcpNotification::ExecCommandEnd(event) => {
+                let acp_call_id = tool_id_map
+                    .borrow_mut()
+                    .remove(&event.call_id)
+                    .context("Missing tool call")?;
+
+                let mut content = String::new();
+                if !event.stdout.is_empty() {
+                    use std::fmt::Write;
+                    writeln!(
+                        &mut content,
+                        "### Output\n\n{}",
+                        md_codeblock("", &event.stdout)
+                    )
+                    .unwrap();
+                }
+                if !event.stdout.is_empty() && !event.stderr.is_empty() {
+                    use std::fmt::Write;
+                    writeln!(&mut content).unwrap();
+                }
+                if !event.stderr.is_empty() {
+                    use std::fmt::Write;
+                    writeln!(
+                        &mut content,
+                        "### Error\n\n{}",
+                        md_codeblock("", &event.stderr)
+                    )
+                    .unwrap();
+                }
+                let success = event.exit_code == 0;
+                if !success {
+                    use std::fmt::Write;
+                    writeln!(&mut content, "\nExit code: `{}`", event.exit_code).unwrap();
+                }
+
+                delegate
+                    .update_tool_call(acp::UpdateToolCallParams {
+                        tool_call_id: acp_call_id,
+                        status: if success {
+                            acp::ToolCallStatus::Finished
+                        } else {
+                            acp::ToolCallStatus::Error
+                        },
+                        content: Some(acp::ToolCallContent::Markdown { markdown: content }),
+                    })
+                    .await?;
+            }
+            AcpNotification::ExecApprovalRequest(event) => {
+                let inner_command = strip_bash_lc_and_escape(&event.command);
+                let root_command = inner_command
+                    .split(" ")
+                    .next()
+                    .map(|s| s.to_string())
+                    .unwrap_or_default();
+
+                let response = delegate
+                    .request_tool_call_confirmation(acp::RequestToolCallConfirmationParams {
+                        tool_call: acp::PushToolCallParams {
+                            label: format!("`{}`", inner_command),
+                            icon: acp::Icon::Terminal,
+                            content: None,
+                            locations: vec![],
+                        },
+                        confirmation: acp::ToolCallConfirmation::Execute {
+                            command: inner_command,
+                            root_command,
+                            description: event.reason,
+                        },
+                    })
+                    .await?;
+
+                tool_id_map.borrow_mut().insert(event.call_id, response.id);
+
+                // todo! approval
+            }
             AcpNotification::Other => {}
         }
 
@@ -285,6 +439,11 @@ struct CodexEvent {
 pub enum AcpNotification {
     AgentMessage(AgentMessageEvent),
     AgentReasoning(AgentReasoningEvent),
+    McpToolCallBegin(McpToolCallBeginEvent),
+    McpToolCallEnd(McpToolCallEndEvent),
+    ExecCommandBegin(ExecCommandBeginEvent),
+    ExecCommandEnd(ExecCommandEndEvent),
+    ExecApprovalRequest(ExecApprovalRequestEvent),
     #[serde(other)]
     Other,
 }
@@ -298,3 +457,92 @@ pub struct AgentMessageEvent {
 pub struct AgentReasoningEvent {
     pub text: String,
 }
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct McpToolCallBeginEvent {
+    pub call_id: String,
+    pub server: String,
+    pub tool: String,
+    pub arguments: Option<serde_json::Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct McpToolCallEndEvent {
+    pub call_id: String,
+    pub result: Result<serde_json::Value, String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ExecCommandBeginEvent {
+    pub call_id: String,
+    pub command: Vec<String>,
+    pub cwd: PathBuf,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ExecCommandEndEvent {
+    pub call_id: String,
+    pub stdout: String,
+    pub stderr: String,
+    pub exit_code: i32,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ExecApprovalRequestEvent {
+    pub call_id: String,
+    pub command: Vec<String>,
+    pub cwd: PathBuf,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reason: Option<String>,
+}
+
+// Helper functions
+fn md_codeblock(lang: &str, content: &str) -> String {
+    if content.ends_with('\n') {
+        format!("```{}\n{}```", lang, content)
+    } else {
+        format!("```{}\n{}\n```", lang, content)
+    }
+}
+
+fn strip_bash_lc_and_escape(command: &[String]) -> String {
+    match command {
+        // exactly three items
+        [first, second, third]
+            // first two must be "bash", "-lc"
+            if first == "bash" && second == "-lc" =>
+        {
+            third.clone()
+        }
+        _ => escape_command(command),
+    }
+}
+
+fn escape_command(command: &[String]) -> String {
+    shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
+}
+
+fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::ToolCallContent> {
+    let mut content = String::new();
+
+    for chunk in chunks {
+        match chunk {
+            ToolResponseContent::Text { text } => content.push_str(&text),
+            ToolResponseContent::Image { .. } => {
+                // todo!
+            }
+            ToolResponseContent::Audio { .. } => {
+                // todo!
+            }
+            ToolResponseContent::Resource { .. } => {
+                // todo!
+            }
+        }
+    }
+
+    if !content.is_empty() {
+        Some(acp::ToolCallContent::Markdown { markdown: content })
+    } else {
+        None
+    }
+}