Track tool use counts (#28722)

Michael Sloan created

Release Notes:

- N/A

Change summary

Cargo.lock                 |  1 +
crates/eval/Cargo.toml     |  1 +
crates/eval/src/example.rs | 17 ++++++++++++++---
3 files changed, 16 insertions(+), 3 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4882,6 +4882,7 @@ dependencies = [
  "chrono",
  "clap",
  "client",
+ "collections",
  "context_server",
  "dap",
  "env_logger 0.11.8",

crates/eval/Cargo.toml 🔗

@@ -14,6 +14,7 @@ assistant_tools.workspace = true
 chrono.workspace = true
 clap.workspace = true
 client.workspace = true
+collections.workspace = true
 context_server.workspace = true
 dap.workspace = true
 env_logger.workspace = true

crates/eval/src/example.rs 🔗

@@ -2,6 +2,7 @@ use agent::{RequestKind, ThreadEvent, ThreadStore};
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::ToolWorkingSet;
 use client::proto::LspWorkProgress;
+use collections::HashMap;
 use dap::DapRegistry;
 use futures::channel::{mpsc, oneshot};
 use futures::{FutureExt, StreamExt as _};
@@ -63,6 +64,7 @@ pub struct RunOutput {
     pub diagnostics: String,
     pub response_count: usize,
     pub token_usage: TokenUsage,
+    pub tool_use_counts: HashMap<Arc<str>, u32>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -270,12 +272,16 @@ impl Example {
                 log_file.flush().log_err();
             }
 
+            let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
+                Mutex::new(HashMap::default()).into();
+
             let (tx, rx) = oneshot::channel();
             let mut tx = Some(tx);
 
-            let _subscription = cx.subscribe(&thread, {
+            let subscription = cx.subscribe(&thread, {
                 let log_file = this.log_file.clone();
                 let name = this.name.clone();
+                let tool_use_counts = tool_use_counts.clone();
                 move |thread, event: &ThreadEvent, cx| {
                     let mut log_file = log_file.lock().unwrap();
 
@@ -327,8 +333,11 @@ impl Example {
                                 writeln!(&mut log_file, "\n{}", message).log_err();
                             }
                             if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
-                                let message = format!("\n{}\n", tool_result.content);
-                                writeln!(&mut log_file, "{}", message).log_err();
+                                writeln!(&mut log_file, "\n{}\n", tool_result.content).log_err();
+                                let mut tool_use_counts = tool_use_counts.lock().unwrap();
+                                *tool_use_counts
+                                    .entry(tool_result.tool_name.clone())
+                                    .or_insert(0) += 1;
                             }
                         }
                         _ => {}
@@ -357,6 +366,7 @@ impl Example {
                 })?
                 .await?;
 
+            drop(subscription);
             drop(lsp_open_handle_and_store);
 
             thread.update(cx, |thread, _cx| {
@@ -369,6 +379,7 @@ impl Example {
                     diagnostics,
                     response_count,
                     token_usage: thread.cumulative_token_usage(),
+                    tool_use_counts: tool_use_counts.lock().unwrap().clone(),
                 }
             })
         })