Agent Eval: Fail example when there are no events in 2 minutes (#28725)

Michael Sloan created

Release Notes:

- N/A

Change summary

crates/eval/src/example.rs | 130 +++++++++++++++++++++------------------
1 file changed, 71 insertions(+), 59 deletions(-)

Detailed changes

crates/eval/src/example.rs 🔗

@@ -4,8 +4,8 @@ use assistant_tool::ToolWorkingSet;
 use client::proto::LspWorkProgress;
 use collections::HashMap;
 use dap::DapRegistry;
-use futures::channel::{mpsc, oneshot};
-use futures::{FutureExt, StreamExt as _};
+use futures::channel::mpsc;
+use futures::{FutureExt, StreamExt as _, select_biased};
 use gpui::{App, AsyncApp, Entity, Task};
 use handlebars::Handlebars;
 use language::{DiagnosticSeverity, OffsetRangeExt};
@@ -35,6 +35,8 @@ pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
 pub const REPOS_DIR: &str = "./crates/eval/repos";
 pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
 
+const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
+
 #[derive(Clone, Debug, Deserialize)]
 pub struct ExampleBase {
     pub url: String,
@@ -277,77 +279,87 @@ impl Example {
             let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
                 Mutex::new(HashMap::default()).into();
 
-            let (tx, rx) = oneshot::channel();
-            let mut tx = Some(tx);
+            let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
+
+            let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
+                thread_event_tx.unbounded_send(event.clone()).log_err();
+            });
 
-            let subscription = cx.subscribe(&thread, {
+            let event_handler_task = cx.spawn({
                 let log_file = this.log_file.clone();
                 let name = this.name.clone();
                 let tool_use_counts = tool_use_counts.clone();
-                move |thread, event: &ThreadEvent, cx| {
-                    let mut log_file = log_file.lock().unwrap();
-
-                    match event {
-                        ThreadEvent::Stopped(reason) => match reason {
-                            Ok(StopReason::EndTurn) => {
-                                if let Some(tx) = tx.take() {
-                                    tx.send(Ok(())).ok();
-                                }
+                let thread = thread.downgrade();
+                async move |cx| {
+                    loop {
+                        let event = select_biased! {
+                            event = thread_event_rx.next() => event,
+                            _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
+                                return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
                             }
-                            Ok(StopReason::MaxTokens) => {
-                                if let Some(tx) = tx.take() {
-                                    tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
+                        };
+                        let Some(event) = event else {
+                            return Err(anyhow!("ThreadEvent channel ended early"));
+                        };
+
+                        let mut log_file = log_file.lock().unwrap();
+
+                        match event {
+                            ThreadEvent::Stopped(reason) => match reason {
+                                Ok(StopReason::EndTurn) => {
+                                    return Ok(());
                                 }
-                            }
-                            Ok(StopReason::ToolUse) => {}
-                            Err(error) => {
-                                if let Some(tx) = tx.take() {
-                                    tx.send(Err(anyhow!(error.clone()))).ok();
+                                Ok(StopReason::MaxTokens) => {
+                                    return Err(anyhow!("Exceeded maximum tokens"));
                                 }
+                                Ok(StopReason::ToolUse) => {}
+                                Err(error) => {
+                                    return Err(anyhow!(error.clone()));
+                                }
+                            },
+                            ThreadEvent::ShowError(thread_error) => {
+                                break Err(anyhow!(thread_error.clone()));
                             }
-                        },
-                        ThreadEvent::ShowError(thread_error) => {
-                            if let Some(tx) = tx.take() {
-                                tx.send(Err(anyhow!(thread_error.clone()))).ok();
+                            ThreadEvent::StreamedAssistantText(_, chunk) => {
+                                write!(&mut log_file, "{}", chunk).log_err();
                             }
-                        }
-                        ThreadEvent::StreamedAssistantText(_, chunk) => {
-                            write!(&mut log_file, "{}", chunk).log_err();
-                        }
-                        ThreadEvent::StreamedAssistantThinking(_, chunk) => {
-                            write!(&mut log_file, "{}", chunk).log_err();
-                        }
-                        ThreadEvent::UsePendingTools { tool_uses } => {
-                            writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
-                            for tool_use in tool_uses {
-                                writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
-                                    .log_err();
+                            ThreadEvent::StreamedAssistantThinking(_, chunk) => {
+                                write!(&mut log_file, "{}", chunk).log_err();
                             }
-                        }
-                        ThreadEvent::ToolFinished {
-                            tool_use_id,
-                            pending_tool_use,
-                            ..
-                        } => {
-                            if let Some(tool_use) = pending_tool_use {
-                                let message = format!("TOOL FINISHED: {}", tool_use.name);
-                                println!("{name}> {message}");
-                                writeln!(&mut log_file, "\n{}", message).log_err();
+                            ThreadEvent::UsePendingTools { tool_uses } => {
+                                writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
+                                for tool_use in tool_uses {
+                                    writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
+                                        .log_err();
+                                }
                             }
-                            if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
-                                writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
-                                let mut tool_use_counts = tool_use_counts.lock().unwrap();
-                                *tool_use_counts
-                                    .entry(tool_result.tool_name.clone())
-                                    .or_insert(0) += 1;
+                            ThreadEvent::ToolFinished {
+                                tool_use_id,
+                                pending_tool_use,
+                                ..
+                            } => {
+                                if let Some(tool_use) = pending_tool_use {
+                                    let message = format!("TOOL FINISHED: {}", tool_use.name);
+                                    println!("{name}> {message}");
+                                    writeln!(&mut log_file, "\n{}", message).log_err();
+                                }
+                                thread.update(cx, |thread, _cx| {
+                                    if let Some(tool_result) = thread.tool_result(&tool_use_id) {
+                                        writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
+                                        let mut tool_use_counts = tool_use_counts.lock().unwrap();
+                                        *tool_use_counts
+                                            .entry(tool_result.tool_name.clone())
+                                            .or_insert(0) += 1;
+                                    }
+                                })?;
                             }
+                            _ => {}
                         }
-                        _ => {}
-                    }
 
-                    log_file.flush().log_err();
+                        log_file.flush().log_err();
+                    }
                 }
-            })?;
+            });
 
             thread.update(cx, |thread, cx| {
                 let context = vec![];
@@ -355,7 +367,7 @@ impl Example {
                 thread.send_to_model(model, RequestKind::Chat, cx);
             })?;
 
-            rx.await??;
+            event_handler_task.await?;
 
             if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
                 wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;