diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index f9a955eb9f5c7f5cd5ab077ed3c3afd9dfcd4b8b..804e4683a7cc20ac2bcd80f10139d641aa864b98 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -804,6 +804,7 @@ pub enum AcpThreadEvent { Error, LoadError(LoadError), PromptCapabilitiesUpdated, + Refusal, } impl EventEmitter for AcpThread {} @@ -1569,15 +1570,42 @@ impl AcpThread { this.send_task.take(); } - // Truncate entries if the last prompt was refused. + // Handle refusal - distinguish between user prompt and tool call refusals if let Ok(Ok(acp::PromptResponse { stop_reason: acp::StopReason::Refusal, })) = result - && let Some((ix, _)) = this.last_user_message() { - let range = ix..this.entries.len(); - this.entries.truncate(ix); - cx.emit(AcpThreadEvent::EntriesRemoved(range)); + if let Some((user_msg_ix, _)) = this.last_user_message() { + // Check if there's a completed tool call with results after the last user message + // This indicates the refusal is in response to tool output, not the user's prompt + let has_completed_tool_call_after_user_msg = + this.entries.iter().skip(user_msg_ix + 1).any(|entry| { + if let AgentThreadEntry::ToolCall(tool_call) = entry { + // Check if the tool call has completed and has output + matches!(tool_call.status, ToolCallStatus::Completed) + && tool_call.raw_output.is_some() + } else { + false + } + }); + + if has_completed_tool_call_after_user_msg { + // Refusal is due to tool output - don't truncate, just notify + // The model refused based on what the tool returned + cx.emit(AcpThreadEvent::Refusal); + } else { + // User prompt was refused - truncate back to before the user message + let range = user_msg_ix..this.entries.len(); + if range.start < range.end { + this.entries.truncate(user_msg_ix); + cx.emit(AcpThreadEvent::EntriesRemoved(range)); + } + cx.emit(AcpThreadEvent::Refusal); + } + } else { + // No user message found, treat as general refusal + cx.emit(AcpThreadEvent::Refusal); + } } cx.emit(AcpThreadEvent::Stopped); @@ -2681,6 +2709,187 @@ mod tests { assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]); } + #[gpui::test] + async fn test_tool_result_refusal(cx: &mut TestAppContext) { + use std::sync::atomic::AtomicUsize; + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, None, cx).await; + + // Create a connection that simulates refusal after tool result + let prompt_count = Arc::new(AtomicUsize::new(0)); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let prompt_count = prompt_count.clone(); + move |_request, thread, mut cx| { + let count = prompt_count.fetch_add(1, SeqCst); + async move { + if count == 0 { + // First prompt: Generate a tool call with result + thread.update(&mut cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("tool1".into()), + title: "Test Tool".into(), + kind: acp::ToolKind::Fetch, + status: acp::ToolCallStatus::Completed, + content: vec![], + locations: vec![], + raw_input: Some(serde_json::json!({"query": "test"})), + raw_output: Some( + serde_json::json!({"result": "inappropriate content"}), + ), + }), + cx, + ) + .unwrap(); + })?; + + // Now return refusal because of the tool result + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Refusal, + }) + } else { + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + } + .boxed_local() + } + })); + + let thread = cx + .update(|cx| connection.new_thread(project, Path::new("/test"), cx)) + .await + .unwrap(); + + // Track if we see a Refusal event + let saw_refusal_event = Arc::new(std::sync::Mutex::new(false)); + let saw_refusal_event_captured = saw_refusal_event.clone(); + thread.update(cx, |_thread, cx| { + cx.subscribe( + &thread, + move |_thread, _event_thread, event: &AcpThreadEvent, _cx| { + if matches!(event, AcpThreadEvent::Refusal) { + *saw_refusal_event_captured.lock().unwrap() = true; + } + }, + ) + .detach(); + }); + + // Send a user message - this will trigger tool call and then refusal + let send_task = thread.update(cx, |thread, cx| { + thread.send( + vec![acp::ContentBlock::Text(acp::TextContent { + text: "Hello".into(), + annotations: None, + })], + cx, + ) + }); + cx.background_executor.spawn(send_task).detach(); + cx.run_until_parked(); + + // Verify that: + // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt) + // 2. The user message was NOT truncated + assert!( + *saw_refusal_event.lock().unwrap(), + "Refusal event should be emitted for tool result refusals" + ); + + thread.read_with(cx, |thread, _| { + let entries = thread.entries(); + assert!(entries.len() >= 2, "Should have user message and tool call"); + + // Verify user message is still there + assert!( + matches!(entries[0], AgentThreadEntry::UserMessage(_)), + "User message should not be truncated" + ); + + // Verify tool call is there with result + if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] { + assert!( + tool_call.raw_output.is_some(), + "Tool call should have output" + ); + } else { + panic!("Expected tool call at index 1"); + } + }); + } + + #[gpui::test] + async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, None, cx).await; + + let refuse_next = Arc::new(AtomicBool::new(false)); + let connection = Rc::new(FakeAgentConnection::new().on_user_message({ + let refuse_next = refuse_next.clone(); + move |_request, _thread, _cx| { + if refuse_next.load(SeqCst) { + async move { + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Refusal, + }) + } + .boxed_local() + } else { + async move { + Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + }) + } + .boxed_local() + } + } + })); + + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .await + .unwrap(); + + // Track if we see a Refusal event + let saw_refusal_event = Arc::new(std::sync::Mutex::new(false)); + let saw_refusal_event_captured = saw_refusal_event.clone(); + thread.update(cx, |_thread, cx| { + cx.subscribe( + &thread, + move |_thread, _event_thread, event: &AcpThreadEvent, _cx| { + if matches!(event, AcpThreadEvent::Refusal) { + *saw_refusal_event_captured.lock().unwrap() = true; + } + }, + ) + .detach(); + }); + + // Send a message that will be refused + refuse_next.store(true, SeqCst); + cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx))) + .await + .unwrap(); + + // Verify that a Refusal event WAS emitted for user prompt refusal + assert!( + *saw_refusal_event.lock().unwrap(), + "Refusal event should be emitted for user prompt refusals" + ); + + // Verify the message was truncated (user prompt refusal) + thread.read_with(cx, |thread, cx| { + assert_eq!(thread.to_markdown(cx), ""); + }); + } + #[gpui::test] async fn test_refusal(cx: &mut TestAppContext) { init_test(cx); @@ -2744,8 +2953,8 @@ mod tests { ); }); - // Simulate refusing the second message, ensuring the conversation gets - // truncated to before sending it. + // Simulate refusing the second message. The message should be truncated + // when a user prompt is refused. refuse_next.store(true, SeqCst); cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx))) .await diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs index e52101113a61c7379be54e25f1784ac16b660200..e8e0d4be7f9dd06f2a7b98761dc2b6287f968ba4 100644 --- a/crates/agent_ui/src/acp/model_selector_popover.rs +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -36,6 +36,14 @@ impl AcpModelSelectorPopover { pub fn toggle(&self, window: &mut Window, cx: &mut Context) { self.menu_handle.toggle(window, cx); } + + pub fn active_model_name(&self, cx: &App) -> Option { + self.selector + .read(cx) + .delegate + .active_model() + .map(|model| model.name.clone()) + } } impl Render for AcpModelSelectorPopover { diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index dc8abd99ae9e9afb07a5e3360e1216a07a528d01..60b3166a57aebc02fba82d4b350de0e48b84ef94 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -78,10 +78,12 @@ enum ThreadFeedback { Negative, } +#[derive(Debug)] enum ThreadError { PaymentRequired, ModelRequestLimitReached(cloud_llm_client::Plan), ToolUseLimitReached, + Refusal, AuthenticationRequired(SharedString), Other(SharedString), } @@ -1255,6 +1257,14 @@ impl AcpThreadView { cx, ); } + AcpThreadEvent::Refusal => { + self.thread_retry_status.take(); + self.thread_error = Some(ThreadError::Refusal); + let model_or_agent_name = self.get_current_model_name(cx); + let notification_message = + format!("{} refused to respond to this request", model_or_agent_name); + self.notify_with_sound(¬ification_message, IconName::Warning, window, cx); + } AcpThreadEvent::Error => { self.thread_retry_status.take(); self.notify_with_sound( @@ -4740,6 +4750,7 @@ impl AcpThreadView { fn render_thread_error(&self, window: &mut Window, cx: &mut Context) -> Option
{ let content = match self.thread_error.as_ref()? { ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx), + ThreadError::Refusal => self.render_refusal_error(cx), ThreadError::AuthenticationRequired(error) => { self.render_authentication_required_error(error.clone(), cx) } @@ -4755,6 +4766,43 @@ impl AcpThreadView { Some(div().child(content)) } + fn get_current_model_name(&self, cx: &App) -> SharedString { + // For native agent (Zed Agent), use the specific model name (e.g., "Claude 3.5 Sonnet") + // For ACP agents, use the agent name (e.g., "Claude Code", "Gemini CLI") + // This provides better clarity about what refused the request + if self + .agent + .clone() + .downcast::() + .is_some() + { + // Native agent - use the model name + self.model_selector + .as_ref() + .and_then(|selector| selector.read(cx).active_model_name(cx)) + .unwrap_or_else(|| SharedString::from("The model")) + } else { + // ACP agent - use the agent name (e.g., "Claude Code", "Gemini CLI") + self.agent.name() + } + } + + fn render_refusal_error(&self, cx: &mut Context<'_, Self>) -> Callout { + let model_or_agent_name = self.get_current_model_name(cx); + let refusal_message = format!( + "{} refused to respond to this prompt. This can happen when a model believes the prompt violates its content policy or safety guidelines, so rephrasing it can sometimes address the issue.", + model_or_agent_name + ); + + Callout::new() + .severity(Severity::Error) + .title("Request Refused") + .icon(IconName::XCircle) + .description(refusal_message.clone()) + .actions_slot(self.create_copy_button(&refusal_message)) + .dismiss_action(self.dismiss_error_button(cx)) + } + fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout { let can_resume = self .thread() @@ -5382,6 +5430,33 @@ pub(crate) mod tests { ); } + #[gpui::test] + async fn test_refusal_handling(cx: &mut TestAppContext) { + init_test(cx); + + let (thread_view, cx) = + setup_thread_view(StubAgentServer::new(RefusalAgentConnection), cx).await; + + let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone()); + message_editor.update_in(cx, |editor, window, cx| { + editor.set_text("Do something harmful", window, cx); + }); + + thread_view.update_in(cx, |thread_view, window, cx| { + thread_view.send(window, cx); + }); + + cx.run_until_parked(); + + // Check that the refusal error is set + thread_view.read_with(cx, |thread_view, _cx| { + assert!( + matches!(thread_view.thread_error, Some(ThreadError::Refusal)), + "Expected refusal error to be set" + ); + }); + } + #[gpui::test] async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { init_test(cx); @@ -5617,6 +5692,68 @@ pub(crate) mod tests { } } + /// Simulates a model which always returns a refusal response + #[derive(Clone)] + struct RefusalAgentConnection; + + impl AgentConnection for RefusalAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + _cwd: &Path, + cx: &mut gpui::App, + ) -> Task>> { + Task::ready(Ok(cx.new(|cx| { + let action_log = cx.new(|_| ActionLog::new(project.clone())); + AcpThread::new( + "RefusalAgentConnection", + self, + project, + action_log, + SessionId("test".into()), + watch::Receiver::constant(acp::PromptCapabilities { + image: true, + audio: true, + embedded_context: true, + }), + Vec::new(), + cx, + ) + }))) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] + } + + fn authenticate( + &self, + _method_id: acp::AuthMethodId, + _cx: &mut App, + ) -> Task> { + unimplemented!() + } + + fn prompt( + &self, + _id: Option, + _params: acp::PromptRequest, + _cx: &mut App, + ) -> Task> { + Task::ready(Ok(acp::PromptResponse { + stop_reason: acp::StopReason::Refusal, + })) + } + + fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { + unimplemented!() + } + + fn into_any(self: Rc) -> Rc { + self + } + } + pub(crate) fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/agent_ui/src/active_thread.rs b/crates/agent_ui/src/active_thread.rs index 371a59e7eb9eb88dc5200251f971ef851162b630..fbba3eaffdd818bd1496b83f9f3081cbf52735ed 100644 --- a/crates/agent_ui/src/active_thread.rs +++ b/crates/agent_ui/src/active_thread.rs @@ -1001,8 +1001,22 @@ impl ActiveThread { // Don't notify for intermediate tool use } Ok(StopReason::Refusal) => { + let model_name = self + .thread + .read(cx) + .configured_model() + .map(|configured| configured.model.name().0.to_string()) + .unwrap_or_else(|| "The model".to_string()); + let refusal_message = format!( + "{} refused to respond to this prompt. This can happen when a model believes the prompt violates its content policy or safety guidelines, so rephrasing it can sometimes address the issue.", + model_name + ); + self.last_error = Some(ThreadError::Message { + header: SharedString::from("Request Refused"), + message: SharedString::from(refusal_message), + }); self.notify_with_sound( - "Language model refused to respond", + format!("{} refused to respond", model_name), IconName::Warning, window, cx, diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 74bcb266d52ac25c91f3243c3e76f1e1f25d770e..f9d7321ca8dd72b791a462d50f262ce0f5531fd5 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1517,7 +1517,10 @@ impl AgentDiff { self.update_reviewing_editors(workspace, window, cx); } } - AcpThreadEvent::Stopped | AcpThreadEvent::Error | AcpThreadEvent::LoadError(_) => { + AcpThreadEvent::Stopped + | AcpThreadEvent::Error + | AcpThreadEvent::LoadError(_) + | AcpThreadEvent::Refusal => { self.update_reviewing_editors(workspace, window, cx); } AcpThreadEvent::TitleUpdated diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 2383963d6c26c6afea1b03d7f28b29bf3a9b4223..cfa5b56358863ece6ab1f6dd024e7be365766853 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -3532,6 +3532,11 @@ impl AgentPanel { ) -> AnyElement { let message_with_header = format!("{}\n{}", header, message); + // Don't show Retry button for refusals + let is_refusal = header == "Request Refused"; + let retry_button = self.render_retry_button(thread); + let copy_button = self.create_copy_button(message_with_header); + Callout::new() .severity(Severity::Error) .icon(IconName::XCircle) @@ -3540,8 +3545,8 @@ impl AgentPanel { .actions_slot( h_flex() .gap_0p5() - .child(self.render_retry_button(thread)) - .child(self.create_copy_button(message_with_header)), + .when(!is_refusal, |this| this.child(retry_button)) + .child(copy_button), ) .dismiss_action(self.dismiss_error_button(thread, cx)) .into_any_element()