agent: Remove last turn after a refusal (#31220)

Marshall Bowers created

This is a follow-up to https://github.com/zed-industries/zed/pull/31217
that removes the last turn after we get a `refusal` stop reason, as
advised by the Anthropic docs.
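
As a rough illustration of which messages count as the "last turn", here is a minimal standalone sketch of the backward walk this change performs. The `Message` struct, the two-variant `Role` enum, and the `refused_turn_ids` helper are simplified stand-ins invented for this example, not the actual types in `crates/agent/src/thread.rs`; the real code calls `thread.delete_message` on each collected id instead of returning them.

```rust
// Simplified stand-ins for the thread's message types (assumptions for
// illustration only; the real `Message` and `Role` carry more data).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Role {
    User,
    Assistant,
}

#[derive(Clone, Copy, Debug)]
struct Message {
    id: usize,
    role: Role,
}

/// Hypothetical helper: collects the ids of the messages that make up the
/// last turn, newest first.
///
/// Walks the thread from the end and stops at the first user message that
/// either starts the thread or directly follows an assistant message, so a
/// run of consecutive user messages is treated as part of the same turn.
fn refused_turn_ids(messages: &[Message]) -> Vec<usize> {
    let mut ids = Vec::new();

    for (ix, message) in messages.iter().enumerate().rev() {
        ids.push(message.id);

        // `ix == 0` is checked first, so `ix - 1` cannot underflow.
        if message.role == Role::User
            && (ix == 0 || messages[ix - 1].role == Role::Assistant)
        {
            break;
        }
    }

    ids
}
```

For a thread of user, assistant, user, user, assistant (ids 0 through 4), this sketch collects ids 4, 3, and 2: the refused assistant reply plus both user messages that prompted it, leaving the earlier exchange intact.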

Meant to include it in that PR, but accidentally merged it before
pushing these changes 🤦🏻‍♂️.

Release Notes:

- N/A

Change summary

crates/agent/src/thread.rs | 29 ++++++++++++++++++++++++++++-
1 file changed, 28 insertions(+), 1 deletion(-)

Detailed changes

crates/agent/src/thread.rs

@@ -1698,7 +1698,34 @@ impl Thread {
                                     project.set_agent_location(None, cx);
                                 });
 
-                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
+                                // Remove the turn that was refused.
+                                //
+                                // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal
+                                {
+                                    let mut messages_to_remove = Vec::new();
+
+                                    for (ix, message) in thread.messages.iter().enumerate().rev() {
+                                        messages_to_remove.push(message.id);
+
+                                        if message.role == Role::User {
+                                            if ix == 0 {
+                                                break;
+                                            }
+
+                                            if let Some(prev_message) = thread.messages.get(ix - 1) {
+                                                if prev_message.role == Role::Assistant {
+                                                    break;
+                                                }
+                                            }
+                                        }
+                                    }
+
+                                    for message_id in messages_to_remove {
+                                        thread.delete_message(message_id, cx);
+                                    }
+                                }
+
+                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                     header: "Language model refusal".into(),
                                     message: "Model refused to generate content for safety reasons.".into(),
                                 }));