Fix eval judging missing final response (#29638)

Richard Feldman created

Fixed an issue where eval thread judges were not considering the last
response in the thread.

The problem was that the judges were getting the full list of messages from
`last_request`, which (being a request!) did not yet include the final response.

Release Notes:

- N/A

Change summary

crates/agent/src/assistant.rs |  2 
crates/eval/src/instance.rs   | 80 +++++++++++++++++++++++++-----------
2 files changed, 56 insertions(+), 26 deletions(-)

Detailed changes

crates/agent/src/assistant.rs 🔗

@@ -41,7 +41,7 @@ use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal}
 pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
 pub use crate::context::{ContextLoadResult, LoadedContext};
 pub use crate::inline_assistant::InlineAssistant;
-pub use crate::thread::{Message, Thread, ThreadEvent};
+pub use crate::thread::{Message, MessageSegment, Thread, ThreadEvent};
 pub use crate::thread_store::ThreadStore;
 pub use agent_diff::{AgentDiff, AgentDiffToolbar};
 

crates/eval/src/instance.rs 🔗

@@ -1,4 +1,4 @@
-use agent::ThreadStore;
+use agent::{Message, MessageSegment, ThreadStore};
 use anyhow::{Context, Result, anyhow, bail};
 use assistant_tool::ToolWorkingSet;
 use client::proto::LspWorkProgress;
@@ -60,7 +60,7 @@ pub struct RunOutput {
     pub response_count: usize,
     pub token_usage: TokenUsage,
     pub tool_metrics: ToolMetrics,
-    pub last_request: LanguageModelRequest,
+    pub all_messages: String,
     pub programmatic_assertions: AssertionsReport,
 }
 
@@ -309,19 +309,15 @@ impl ExampleInstance {
             let thread_store = thread_store.await?;
             let thread =
                 thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
-            let last_request = Rc::new(RefCell::new(None));
 
             thread.update(cx, |thread, _cx| {
                 let mut request_count = 0;
-                let last_request = Rc::clone(&last_request);
                 let previous_diff = Rc::new(RefCell::new("".to_string()));
                 let example_output_dir = this.run_directory.clone();
                 let last_diff_file_path = last_diff_file_path.clone();
                 let messages_json_file_path = example_output_dir.join("last.messages.json");
                 let this = this.clone();
                 thread.set_request_callback(move |request, response_events| {
-                    *last_request.borrow_mut() = Some(request.clone());
-
                     request_count += 1;
                     let messages_file_path = example_output_dir.join(format!("{request_count}.messages.md"));
                     let diff_file_path = example_output_dir.join(format!("{request_count}.diff"));
@@ -397,10 +393,6 @@ impl ExampleInstance {
 
             }
 
-            let Some(last_request) = last_request.borrow_mut().take() else {
-                return Err(anyhow!("No requests ran."));
-            };
-
             if let Some(diagnostics_before) = &diagnostics_before {
                 fs::write(this.run_directory.join("diagnostics_before.txt"), diagnostics_before)?;
             }
@@ -423,7 +415,7 @@ impl ExampleInstance {
                     response_count,
                     token_usage: thread.cumulative_token_usage(),
                     tool_metrics: example_cx.tool_metrics.lock().unwrap().clone(),
-                    last_request,
+                    all_messages: messages_to_markdown(thread.messages()),
                     programmatic_assertions: example_cx.assertions,
                 }
             })
@@ -526,23 +518,23 @@ impl ExampleInstance {
 
         if thread_assertions.is_empty() {
             return (
-                "No diff assertions".to_string(),
+                "No thread assertions".to_string(),
                 AssertionsReport::default(),
             );
         }
 
         let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
-        let judge_diff_prompt_name = "judge_thread_prompt";
+        let judge_thread_prompt_name = "judge_thread_prompt";
         let mut hbs = Handlebars::new();
-        hbs.register_template_string(judge_diff_prompt_name, judge_thread_prompt)
+        hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)
             .unwrap();
 
-        let request_markdown = RequestMarkdown::new(&run_output.last_request);
+        let complete_messages = &run_output.all_messages;
         let to_prompt = |assertion: String| {
             hbs.render(
-                judge_diff_prompt_name,
+                judge_thread_prompt_name,
                 &JudgeThreadInput {
-                    messages: request_markdown.messages.clone(),
+                    messages: complete_messages.clone(),
                     assertion,
                 },
             )
@@ -817,6 +809,51 @@ pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
     }
 }
 
+fn messages_to_markdown<'a>(message_iter: impl IntoIterator<Item = &'a Message>) -> String {
+    let mut messages = String::new();
+    let mut assistant_message_number: u32 = 1;
+
+    for message in message_iter {
+        push_role(&message.role, &mut messages, &mut assistant_message_number);
+
+        for segment in &message.segments {
+            match segment {
+                MessageSegment::Text(text) => {
+                    messages.push_str(&text);
+                    messages.push_str("\n\n");
+                }
+                MessageSegment::Thinking { text, signature } => {
+                    messages.push_str("**Thinking**:\n\n");
+                    if let Some(sig) = signature {
+                        messages.push_str(&format!("Signature: {}\n\n", sig));
+                    }
+                    messages.push_str(&text);
+                    messages.push_str("\n");
+                }
+                MessageSegment::RedactedThinking(items) => {
+                    messages.push_str(&format!(
+                        "**Redacted Thinking**: {} item(s)\n\n",
+                        items.len()
+                    ));
+                }
+            }
+        }
+    }
+
+    messages
+}
+
+fn push_role(role: &Role, buf: &mut String, assistant_message_number: &mut u32) {
+    match role {
+        Role::System => buf.push_str("# ⚙️ SYSTEM\n\n"),
+        Role::User => buf.push_str("# 👤 USER\n\n"),
+        Role::Assistant => {
+            buf.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
+            *assistant_message_number = *assistant_message_number + 1;
+        }
+    }
+}
+
 pub async fn send_language_model_request(
     model: Arc<dyn LanguageModel>,
     request: LanguageModelRequest,
@@ -875,14 +912,7 @@ impl RequestMarkdown {
 
         // Print the messages
         for message in &request.messages {
-            match message.role {
-                Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"),
-                Role::User => messages.push_str("# 👤 USER\n\n"),
-                Role::Assistant => {
-                    messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
-                    assistant_message_number += 1;
-                }
-            };
+            push_role(&message.role, &mut messages, &mut assistant_message_number);
 
             for content in &message.content {
                 match content {