Fix tool call confirmation

Agus Zubiaga created

Change summary

crates/acp_thread/src/acp_thread.rs    | 200 +++++++++++++--------------
crates/agent_servers/src/codex.rs      |  96 ++++++------
crates/agent_servers/src/mcp_server.rs |  41 ++--
3 files changed, 168 insertions(+), 169 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -2,7 +2,7 @@ mod connection;
 pub use connection::*;
 
 use agent_client_protocol as acp;
-use agentic_coding_protocol::{self as acp_old, ToolCallConfirmationOutcome};
+use agentic_coding_protocol as acp_old;
 use anyhow::{Context as _, Result};
 use assistant_tool::ActionLog;
 use buffer_diff::BufferDiff;
@@ -715,10 +715,19 @@ impl AcpThread {
         tool_call: acp::ToolCall,
         cx: &mut Context<Self>,
     ) -> Result<()> {
-        let language_registry = self.project.read(cx).languages().clone();
         let status = ToolCallStatus::Allowed {
             status: tool_call.status,
         };
+        self.update_tool_call_inner(tool_call, status, cx)
+    }
+
+    pub fn update_tool_call_inner(
+        &mut self,
+        tool_call: acp::ToolCall,
+        status: ToolCallStatus,
+        cx: &mut Context<Self>,
+    ) -> Result<()> {
+        let language_registry = self.project.read(cx).languages().clone();
         let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 
         let location = call.locations.last().cloned();
@@ -795,27 +804,10 @@ impl AcpThread {
             respond_tx: tx,
         };
 
-        self.insert_tool_call(tool_call, status, cx);
+        self.update_tool_call_inner(tool_call, status, cx);
         rx
     }
 
-    fn insert_tool_call(
-        &mut self,
-        tool_call: acp::ToolCall,
-        status: ToolCallStatus,
-        cx: &mut Context<Self>,
-    ) {
-        let language_registry = self.project.read(cx).languages().clone();
-        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
-
-        let location = call.locations.last().cloned();
-        if let Some(location) = location {
-            self.set_project_location(location, cx)
-        }
-
-        self.push_entry(AgentThreadEntry::ToolCall(call), cx);
-    }
-
     pub fn authorize_tool_call(
         &mut self,
         id: acp::ToolCallId,
@@ -1194,80 +1186,6 @@ impl OldAcpClientDelegate {
         Ok(())
     }
 
-    pub async fn request_existing_tool_call_confirmation(
-        &self,
-        tool_call_id: acp_old::ToolCallId,
-        confirmation: acp_old::ToolCallConfirmation,
-    ) -> Result<acp_old::ToolCallConfirmationOutcome> {
-        let cx = &mut self.cx.clone();
-
-        let tool_call = into_new_tool_call(
-            acp::ToolCallId(tool_call_id.0.to_string().into()),
-            confirmation,
-        );
-
-        let options = [
-            acp_old::ToolCallConfirmationOutcome::Allow,
-            acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
-            acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
-            acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool,
-            acp_old::ToolCallConfirmationOutcome::Reject,
-        ]
-        .into_iter()
-        .map(|outcome| match outcome {
-            acp_old::ToolCallConfirmationOutcome::Allow => acp::PermissionOption {
-                id: acp::PermissionOptionId(Arc::from("allow")),
-                label: "Allow".to_string(),
-                kind: acp::PermissionOptionKind::AllowOnce,
-            },
-            acp_old::ToolCallConfirmationOutcome::AlwaysAllow => acp::PermissionOption {
-                id: acp::PermissionOptionId(Arc::from("always-allow")),
-                label: "Always Allow".to_string(),
-                kind: acp::PermissionOptionKind::AllowOnce,
-            },
-            acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer => acp::PermissionOption {
-                id: acp::PermissionOptionId(Arc::from("always-allow-mcp-server")),
-                label: "Always Allow MCP Server".to_string(),
-                kind: acp::PermissionOptionKind::AllowOnce,
-            },
-            acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool => acp::PermissionOption {
-                id: acp::PermissionOptionId(Arc::from("always-allow-tool")),
-                label: "Always Allow Tool".to_string(),
-                kind: acp::PermissionOptionKind::AllowOnce,
-            },
-            acp_old::ToolCallConfirmationOutcome::Reject => acp::PermissionOption {
-                id: acp::PermissionOptionId(Arc::from("reject")),
-                label: "Reject".to_string(),
-                kind: acp::PermissionOptionKind::AllowOnce,
-            },
-            acp_old::ToolCallConfirmationOutcome::Cancel => unreachable!(),
-        })
-        .collect();
-
-        let outcome = cx
-            .update(|cx| {
-                self.thread.update(cx, |thread, cx| {
-                    thread.request_tool_call_permission(tool_call, options, cx)
-                })
-            })?
-            .context("Failed to update thread")?;
-
-        let response = match outcome.await {
-            Ok(result) => match result.0.as_ref() {
-                "allow" => acp_old::ToolCallConfirmationOutcome::Allow,
-                "always-allow" => acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
-                "always-allow-tool" => acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool,
-                "always-allow-mcp-server" => {
-                    acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
-                }
-                "reject" => acp_old::ToolCallConfirmationOutcome::Reject,
-            },
-            Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel,
-        };
-
-        Ok(response)
-    }
-
     pub async fn read_text_file_reusing_snapshot(
         &self,
         request: acp_old::ReadTextFileParams,
@@ -1320,16 +1238,96 @@ impl acp_old::Client for OldAcpClientDelegate {
         request: acp_old::RequestToolCallConfirmationParams,
     ) -> Result<acp_old::RequestToolCallConfirmationResponse, acp_old::Error> {
         let cx = &mut self.cx.clone();
-        let ToolCallRequest { id, outcome } = cx
+
+        let old_acp_id = *self.next_tool_call_id.borrow() + 1;
+        self.next_tool_call_id.replace(old_acp_id);
+
+        let tool_call = into_new_tool_call(
+            acp::ToolCallId(old_acp_id.to_string().into()),
+            request.tool_call,
+        );
+
+        let mut options = match request.confirmation {
+            acp_old::ToolCallConfirmation::Edit { .. } => vec![(
+                acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
+                acp::PermissionOptionKind::AllowAlways,
+                "Always Allow Edits".to_string(),
+            )],
+            acp_old::ToolCallConfirmation::Execute { root_command, .. } => vec![(
+                acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
+                acp::PermissionOptionKind::AllowAlways,
+                format!("Always Allow {}", root_command),
+            )],
+            acp_old::ToolCallConfirmation::Mcp {
+                server_name,
+                tool_name,
+                ..
+            } => vec![
+                (
+                    acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer,
+                    acp::PermissionOptionKind::AllowAlways,
+                    format!("Always Allow {}", server_name),
+                ),
+                (
+                    acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool,
+                    acp::PermissionOptionKind::AllowAlways,
+                    format!("Always Allow {}", tool_name),
+                ),
+            ],
+            acp_old::ToolCallConfirmation::Fetch { .. } => vec![(
+                acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
+                acp::PermissionOptionKind::AllowAlways,
+                "Always Allow".to_string(),
+            )],
+            acp_old::ToolCallConfirmation::Other { .. } => vec![(
+                acp_old::ToolCallConfirmationOutcome::AlwaysAllow,
+                acp::PermissionOptionKind::AllowAlways,
+                "Always Allow".to_string(),
+            )],
+        };
+
+        options.extend([
+            (
+                acp_old::ToolCallConfirmationOutcome::Allow,
+                acp::PermissionOptionKind::AllowOnce,
+                "Allow".to_string(),
+            ),
+            (
+                acp_old::ToolCallConfirmationOutcome::Reject,
+                acp::PermissionOptionKind::RejectOnce,
+                "Reject".to_string(),
+            ),
+        ]);
+
+        let mut outcomes = Vec::with_capacity(options.len());
+        let mut acp_options = Vec::with_capacity(options.len());
+
+        for (index, (outcome, kind, label)) in options.into_iter().enumerate() {
+            outcomes.push(outcome);
+            acp_options.push(acp::PermissionOption {
+                id: acp::PermissionOptionId(index.to_string().into()),
+                label,
+                kind,
+            })
+        }
+
+        let response = cx
             .update(|cx| {
-                self.thread
-                    .update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
+                self.thread.update(cx, |thread, cx| {
+                    thread.request_tool_call_permission(tool_call, acp_options, cx)
+                })
             })?
-            .context("Failed to update thread")?;
+            .context("Failed to update thread")?
+            .await;
+
+        let outcome = match response {
+            Ok(option_id) => outcomes[option_id.0.parse::<usize>().unwrap_or(0)].clone(),
+            Err(oneshot::Canceled) => acp_old::ToolCallConfirmationOutcome::Cancel,
+        };
 
         Ok(acp_old::RequestToolCallConfirmationResponse {
-            id,
-            outcome: outcome.await.map_err(acp_old::Error::into_internal_error)?,
+            id: acp_old::ToolCallId(old_acp_id),
+            outcome: outcome,
         })
     }
 
@@ -1350,7 +1348,7 @@ impl acp_old::Client for OldAcpClientDelegate {
                 )
             })
         })?
-        .context("Failed to update thread")?;
+        .context("Failed to update thread")??;
 
         Ok(acp_old::PushToolCallResponse {
             id: acp_old::ToolCallId(old_acp_id),

crates/agent_servers/src/codex.rs 🔗

@@ -247,14 +247,14 @@ impl AgentServer for Codex {
                                     let inner_command =
                                         strip_bash_lc_and_escape(&exec.codex_command);
 
-                                    acp::RequestToolCallConfirmationParams {
-                                        tool_call: acp::PushToolCallParams {
+                                    acp_old::RequestToolCallConfirmationParams {
+                                        tool_call: acp_old::PushToolCallParams {
                                             label: todo!(),
-                                            icon: acp::Icon::Terminal,
+                                            icon: acp_old::Icon::Terminal,
                                             content: None,
                                             locations: vec![],
                                         },
-                                        confirmation: acp::ToolCallConfirmation::Execute {
+                                        confirmation: acp_old::ToolCallConfirmation::Execute {
                                             root_command: inner_command
                                                 .split(" ")
                                                 .next()
@@ -266,21 +266,21 @@ impl AgentServer for Codex {
                                     }
                                 }
                                 CodexElicitation::PatchApproval(patch) => {
-                                    acp::RequestToolCallConfirmationParams {
-                                        tool_call: acp::PushToolCallParams {
+                                    acp_old::RequestToolCallConfirmationParams {
+                                        tool_call: acp_old::PushToolCallParams {
                                             label: "Edit".to_string(),
-                                            icon: acp::Icon::Pencil,
+                                            icon: acp_old::Icon::Pencil,
                                             content: None, // todo!()
                                             locations: patch
                                                 .codex_changes
                                                 .keys()
-                                                .map(|path| acp::ToolCallLocation {
+                                                .map(|path| acp_old::ToolCallLocation {
                                                     path: path.clone(),
                                                     line: None,
                                                 })
                                                 .collect(),
                                         },
-                                        confirmation: acp::ToolCallConfirmation::Edit {
+                                        confirmation: acp_old::ToolCallConfirmation::Edit {
                                             description: Some(patch.message),
                                         },
                                     }
@@ -293,18 +293,18 @@ impl AgentServer for Codex {
                                     .await?;
 
                                 let decision = match response.outcome {
-                                    acp::ToolCallConfirmationOutcome::Allow => {
+                                    acp_old::ToolCallConfirmationOutcome::Allow => {
                                         ReviewDecision::Approved
                                     }
-                                    acp::ToolCallConfirmationOutcome::AlwaysAllow
-                                    | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
-                                    | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => {
+                                    acp_old::ToolCallConfirmationOutcome::AlwaysAllow
+                                    | acp_old::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
+                                    | acp_old::ToolCallConfirmationOutcome::AlwaysAllowTool => {
                                         ReviewDecision::ApprovedForSession
                                     }
-                                    acp::ToolCallConfirmationOutcome::Reject => {
+                                    acp_old::ToolCallConfirmationOutcome::Reject => {
                                         ReviewDecision::Denied
                                     }
-                                    acp::ToolCallConfirmationOutcome::Cancel => {
+                                    acp_old::ToolCallConfirmationOutcome::Cancel => {
                                         ReviewDecision::Abort
                                     }
                                 };
@@ -340,7 +340,7 @@ impl AgentConnection for CodexAgentConnection {
     fn request_any(
         &self,
         params: acp_old::AnyAgentRequest,
-    ) -> LocalBoxFuture<'static, Result<acp::acp_old::AnyAgentResult>> {
+    ) -> LocalBoxFuture<'static, Result<acp_old::acp_old::AnyAgentResult>> {
         let client = self.codex_mcp.client();
         let root_dir = self.root_dir.clone();
         let cancel_request_tx = self.cancel_request_tx.clone();
@@ -350,7 +350,7 @@ impl AgentConnection for CodexAgentConnection {
             match params {
                 // todo: consider sending an empty request so we get the init response?
                 acp_old::AnyAgentRequest::InitializeParams(_) => Ok(
-                    acp_old::AnyAgentResult::InitializeResponse(acp::InitializeResponse {
+                    acp_old::AnyAgentResult::InitializeResponse(acp_old::InitializeResponse {
                         is_authenticated: true,
                         protocol_version: acp_old::ProtocolVersion::latest(),
                     }),
@@ -371,8 +371,8 @@ impl AgentConnection for CodexAgentConnection {
                                         .chunks
                                         .into_iter()
                                         .filter_map(|chunk| match chunk {
-                                            acp::UserMessageChunk::Text { text } => Some(text),
-                                            acp::UserMessageChunk::Path { .. } => {
+                                            acp_old::UserMessageChunk::Text { text } => Some(text),
+                                            acp_old::UserMessageChunk::Path { .. } => {
                                                 // todo!
                                                 None
                                             }
@@ -387,7 +387,7 @@ impl AgentConnection for CodexAgentConnection {
                         .await?;
 
                     Ok(acp_old::AnyAgentResult::SendUserMessageResponse(
-                        acp::SendUserMessageResponse,
+                        acp_old::SendUserMessageResponse,
                     ))
                 }
                 acp_old::AnyAgentRequest::CancelSendMessageParams(_) => {
@@ -398,7 +398,7 @@ impl AgentConnection for CodexAgentConnection {
                     }
 
                     Ok(acp_old::AnyAgentResult::CancelSendMessageResponse(
-                        acp::CancelSendMessageResponse,
+                        acp_old::CancelSendMessageResponse,
                     ))
                 }
             }
@@ -411,7 +411,7 @@ struct CodexAgentConnection {
     codex_mcp: Arc<context_server::ContextServer>,
     root_dir: PathBuf,
     cancel_request_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
-    tool_id_map: Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
+    tool_id_map: Rc<RefCell<HashMap<String, acp_old::ToolCallId>>>,
     _handler_task: Task<()>,
     _request_task: Task<()>,
     _zed_mcp: ZedMcpServer,
@@ -421,13 +421,13 @@ impl CodexAgentConnection {
     async fn handle_acp_notification(
         delegate: &OldAcpClientDelegate,
         event: AcpNotification,
-        tool_id_map: &Rc<RefCell<HashMap<String, acp::ToolCallId>>>,
+        tool_id_map: &Rc<RefCell<HashMap<String, acp_old::ToolCallId>>>,
     ) -> Result<()> {
         match event {
             AcpNotification::AgentMessage(message) => {
                 delegate
-                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
-                        chunk: acp::AssistantMessageChunk::Text {
+                    .stream_assistant_message_chunk(acp_old::StreamAssistantMessageChunkParams {
+                        chunk: acp_old::AssistantMessageChunk::Text {
                             text: message.message,
                         },
                     })
@@ -435,8 +435,8 @@ impl CodexAgentConnection {
             }
             AcpNotification::AgentReasoning(message) => {
                 delegate
-                    .stream_assistant_message_chunk(acp::StreamAssistantMessageChunkParams {
-                        chunk: acp::AssistantMessageChunk::Thought {
+                    .stream_assistant_message_chunk(acp_old::StreamAssistantMessageChunkParams {
+                        chunk: acp_old::AssistantMessageChunk::Thought {
                             thought: message.text,
                         },
                     })
@@ -444,11 +444,11 @@ impl CodexAgentConnection {
             }
             AcpNotification::McpToolCallBegin(event) => {
                 let result = delegate
-                    .push_tool_call(acp::PushToolCallParams {
+                    .push_tool_call(acp_old::PushToolCallParams {
                         label: format!("`{}: {}`", event.server, event.tool),
-                        icon: acp::Icon::Hammer,
+                        icon: acp_old::Icon::Hammer,
                         content: event.arguments.and_then(|args| {
-                            Some(acp::ToolCallContent::Markdown {
+                            Some(acp_old::ToolCallContent::Markdown {
                                 markdown: md_codeblock(
                                     "json",
                                     &serde_json::to_string_pretty(&args).ok()?,
@@ -473,26 +473,26 @@ impl CodexAgentConnection {
                             serde_json::from_value::<context_server::types::CallToolResponse>(value)
                         {
                             (
-                                acp::ToolCallStatus::Finished,
+                                acp_old::ToolCallStatus::Finished,
                                 mcp_tool_content_to_acp(response.content),
                             )
                         } else {
                             (
-                                acp::ToolCallStatus::Error,
-                                Some(acp::ToolCallContent::Markdown {
+                                acp_old::ToolCallStatus::Error,
+                                Some(acp_old::ToolCallContent::Markdown {
                                     markdown: "Failed to parse tool response".to_string(),
                                 }),
                             )
                         }
                     }
                     Err(error) => (
-                        acp::ToolCallStatus::Error,
-                        Some(acp::ToolCallContent::Markdown { markdown: error }),
+                        acp_old::ToolCallStatus::Error,
+                        Some(acp_old::ToolCallContent::Markdown { markdown: error }),
                     ),
                 };
 
                 delegate
-                    .update_tool_call(acp::UpdateToolCallParams {
+                    .update_tool_call(acp_old::UpdateToolCallParams {
                         tool_call_id: acp_call_id,
                         status,
                         content,
@@ -503,9 +503,9 @@ impl CodexAgentConnection {
                 let inner_command = strip_bash_lc_and_escape(&event.command);
 
                 let result = delegate
-                    .push_tool_call(acp::PushToolCallParams {
+                    .push_tool_call(acp_old::PushToolCallParams {
                         label: format!("`{}`", inner_command),
-                        icon: acp::Icon::Terminal,
+                        icon: acp_old::Icon::Terminal,
                         content: None,
                         locations: vec![],
                     })
@@ -549,14 +549,14 @@ impl CodexAgentConnection {
                 }
 
                 delegate
-                    .update_tool_call(acp::UpdateToolCallParams {
+                    .update_tool_call(acp_old::UpdateToolCallParams {
                         tool_call_id: acp_call_id,
                         status: if success {
-                            acp::ToolCallStatus::Finished
+                            acp_old::ToolCallStatus::Finished
                         } else {
-                            acp::ToolCallStatus::Error
+                            acp_old::ToolCallStatus::Error
                         },
-                        content: Some(acp::ToolCallContent::Markdown { markdown: content }),
+                        content: Some(acp_old::ToolCallContent::Markdown { markdown: content }),
                     })
                     .await?;
             }
@@ -569,14 +569,14 @@ impl CodexAgentConnection {
                     .unwrap_or_default();
 
                 let response = delegate
-                    .request_tool_call_confirmation(acp::RequestToolCallConfirmationParams {
-                        tool_call: acp::PushToolCallParams {
+                    .request_tool_call_confirmation(acp_old::RequestToolCallConfirmationParams {
+                        tool_call: acp_old::PushToolCallParams {
                             label: format!("`{}`", inner_command),
-                            icon: acp::Icon::Terminal,
+                            icon: acp_old::Icon::Terminal,
                             content: None,
                             locations: vec![],
                         },
-                        confirmation: acp::ToolCallConfirmation::Execute {
+                        confirmation: acp_old::ToolCallConfirmation::Execute {
                             command: inner_command,
                             root_command,
                             description: event.reason,
@@ -697,7 +697,7 @@ fn escape_command(command: &[String]) -> String {
     shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "))
 }
 
-fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::ToolCallContent> {
+fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp_old::ToolCallContent> {
     let mut content = String::new();
 
     for chunk in chunks {
@@ -716,7 +716,7 @@ fn mcp_tool_content_to_acp(chunks: Vec<ToolResponseContent>) -> Option<acp::Tool
     }
 
     if !content.is_empty() {
-        Some(acp::ToolCallContent::Markdown { markdown: content })
+        Some(acp_old::ToolCallContent::Markdown { markdown: content })
     } else {
         None
     }

crates/agent_servers/src/mcp_server.rs 🔗

@@ -266,27 +266,28 @@ impl ZedMcpServer {
                 None => delegate.push_tool_call(claude_tool.as_acp()).await?.id,
             };
 
-            let outcome = delegate
-                .request_existing_tool_call_confirmation(
-                    tool_call_id,
-                    claude_tool.confirmation(None),
-                )
-                .await?;
+            todo!("use regular request_tool_call_confirmation")
+            // let outcome = delegate
+            //     .request_existing_tool_call_confirmation(
+            //         tool_call_id,
+            //         claude_tool.confirmation(None),
+            //     )
+            //     .await?;
 
-            match outcome {
-                acp::ToolCallConfirmationOutcome::Allow
-                | acp::ToolCallConfirmationOutcome::AlwaysAllow
-                | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
-                | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse {
-                    behavior: PermissionToolBehavior::Allow,
-                    updated_input: params.input,
-                }),
-                acp::ToolCallConfirmationOutcome::Reject
-                | acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse {
-                    behavior: PermissionToolBehavior::Deny,
-                    updated_input: params.input,
-                }),
-            }
+            // match outcome {
+            //     acp::ToolCallConfirmationOutcome::Allow
+            //     | acp::ToolCallConfirmationOutcome::AlwaysAllow
+            //     | acp::ToolCallConfirmationOutcome::AlwaysAllowMcpServer
+            //     | acp::ToolCallConfirmationOutcome::AlwaysAllowTool => Ok(PermissionToolResponse {
+            //         behavior: PermissionToolBehavior::Allow,
+            //         updated_input: params.input,
+            //     }),
+            //     acp::ToolCallConfirmationOutcome::Reject
+            //     | acp::ToolCallConfirmationOutcome::Cancel => Ok(PermissionToolResponse {
+            //         behavior: PermissionToolBehavior::Deny,
+            //         updated_input: params.input,
+            //     }),
+            // }
         })
     }
 }