@@ -4,8 +4,8 @@ use assistant_tool::ToolWorkingSet;
use client::proto::LspWorkProgress;
use collections::HashMap;
use dap::DapRegistry;
-use futures::channel::{mpsc, oneshot};
-use futures::{FutureExt, StreamExt as _};
+use futures::channel::mpsc;
+use futures::{FutureExt, StreamExt as _, select_biased};
use gpui::{App, AsyncApp, Entity, Task};
use handlebars::Handlebars;
use language::{DiagnosticSeverity, OffsetRangeExt};
@@ -35,6 +35,8 @@ pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
pub const REPOS_DIR: &str = "./crates/eval/repos";
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";
+const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
+
#[derive(Clone, Debug, Deserialize)]
pub struct ExampleBase {
pub url: String,
@@ -277,77 +279,87 @@ impl Example {
let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
Mutex::new(HashMap::default()).into();
- let (tx, rx) = oneshot::channel();
- let mut tx = Some(tx);
+ let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
+
+ let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
+ thread_event_tx.unbounded_send(event.clone()).log_err();
+ });
- let subscription = cx.subscribe(&thread, {
+ let event_handler_task = cx.spawn({
let log_file = this.log_file.clone();
let name = this.name.clone();
let tool_use_counts = tool_use_counts.clone();
- move |thread, event: &ThreadEvent, cx| {
- let mut log_file = log_file.lock().unwrap();
-
- match event {
- ThreadEvent::Stopped(reason) => match reason {
- Ok(StopReason::EndTurn) => {
- if let Some(tx) = tx.take() {
- tx.send(Ok(())).ok();
- }
+ let thread = thread.downgrade();
+ async move |cx| {
+ loop {
+ let event = select_biased! {
+ event = thread_event_rx.next() => event,
+ _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
+ return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
}
- Ok(StopReason::MaxTokens) => {
- if let Some(tx) = tx.take() {
- tx.send(Err(anyhow!("Exceeded maximum tokens"))).ok();
+ };
+ let Some(event) = event else {
+ return Err(anyhow!("ThreadEvent channel ended early"));
+ };
+
+ let mut log_file = log_file.lock().unwrap();
+
+ match event {
+ ThreadEvent::Stopped(reason) => match reason {
+ Ok(StopReason::EndTurn) => {
+ return Ok(());
}
- }
- Ok(StopReason::ToolUse) => {}
- Err(error) => {
- if let Some(tx) = tx.take() {
- tx.send(Err(anyhow!(error.clone()))).ok();
+ Ok(StopReason::MaxTokens) => {
+ return Err(anyhow!("Exceeded maximum tokens"));
}
+ Ok(StopReason::ToolUse) => {}
+ Err(error) => {
+ return Err(anyhow!(error.clone()));
+ }
+ },
+ ThreadEvent::ShowError(thread_error) => {
+ break Err(anyhow!(thread_error.clone()));
}
- },
- ThreadEvent::ShowError(thread_error) => {
- if let Some(tx) = tx.take() {
- tx.send(Err(anyhow!(thread_error.clone()))).ok();
+ ThreadEvent::StreamedAssistantText(_, chunk) => {
+ write!(&mut log_file, "{}", chunk).log_err();
}
- }
- ThreadEvent::StreamedAssistantText(_, chunk) => {
- write!(&mut log_file, "{}", chunk).log_err();
- }
- ThreadEvent::StreamedAssistantThinking(_, chunk) => {
- write!(&mut log_file, "{}", chunk).log_err();
- }
- ThreadEvent::UsePendingTools { tool_uses } => {
- writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
- for tool_use in tool_uses {
- writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
- .log_err();
+ ThreadEvent::StreamedAssistantThinking(_, chunk) => {
+ write!(&mut log_file, "{}", chunk).log_err();
}
- }
- ThreadEvent::ToolFinished {
- tool_use_id,
- pending_tool_use,
- ..
- } => {
- if let Some(tool_use) = pending_tool_use {
- let message = format!("TOOL FINISHED: {}", tool_use.name);
- println!("{name}> {message}");
- writeln!(&mut log_file, "\n{}", message).log_err();
+ ThreadEvent::UsePendingTools { tool_uses } => {
+ writeln!(&mut log_file, "\n\nUSING TOOLS:").log_err();
+ for tool_use in tool_uses {
+ writeln!(&mut log_file, "{}: {}", tool_use.name, tool_use.input)
+ .log_err();
+ }
}
- if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
- writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
- let mut tool_use_counts = tool_use_counts.lock().unwrap();
- *tool_use_counts
- .entry(tool_result.tool_name.clone())
- .or_insert(0) += 1;
+ ThreadEvent::ToolFinished {
+ tool_use_id,
+ pending_tool_use,
+ ..
+ } => {
+ if let Some(tool_use) = pending_tool_use {
+ let message = format!("TOOL FINISHED: {}", tool_use.name);
+ println!("{name}> {message}");
+ writeln!(&mut log_file, "\n{}", message).log_err();
+ }
+ thread.update(cx, |thread, _cx| {
+ if let Some(tool_result) = thread.tool_result(&tool_use_id) {
+ writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
+ let mut tool_use_counts = tool_use_counts.lock().unwrap();
+ *tool_use_counts
+ .entry(tool_result.tool_name.clone())
+ .or_insert(0) += 1;
+ }
+ })?;
}
+ _ => {}
}
- _ => {}
- }
- log_file.flush().log_err();
+ log_file.flush().log_err();
+ }
}
- })?;
+ });
thread.update(cx, |thread, cx| {
let context = vec![];
@@ -355,7 +367,7 @@ impl Example {
thread.send_to_model(model, RequestKind::Chat, cx);
})?;
- rx.await??;
+ event_handler_task.await?;
if let Some((_, lsp_store)) = lsp_open_handle_and_store.as_ref() {
wait_for_lang_server(lsp_store, this.name.clone(), cx).await?;