eval: New `add_arg_to_trait_method` example (#29297)

Agus Zubiaga and Richard Feldman created

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>

Change summary

Cargo.lock                                          |   1 
crates/assistant_tools/src/assistant_tools.rs       |   3 
crates/eval/Cargo.toml                              |   1 
crates/eval/src/example.rs                          |  71 ++++++
crates/eval/src/examples/add_arg_to_trait_method.rs | 147 ++++++++++++++
crates/eval/src/examples/file_search.rs             |   2 
crates/eval/src/examples/mod.rs                     |   6 
7 files changed, 222 insertions(+), 9 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4961,6 +4961,7 @@ dependencies = [
  "assistant_tools",
  "async-trait",
  "async-watch",
+ "buffer_diff",
  "chrono",
  "clap",
  "client",

crates/assistant_tools/src/assistant_tools.rs 🔗

@@ -56,7 +56,10 @@ use crate::symbol_info_tool::SymbolInfoTool;
 use crate::terminal_tool::TerminalTool;
 use crate::thinking_tool::ThinkingTool;
 
+pub use create_file_tool::CreateFileToolInput;
+pub use edit_file_tool::EditFileToolInput;
 pub use path_search_tool::PathSearchToolInput;
+pub use read_file_tool::ReadFileToolInput;
 
 pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
     assistant_tool::init(cx);

crates/eval/Cargo.toml 🔗

@@ -11,6 +11,7 @@ assistant_tool.workspace = true
 assistant_tools.workspace = true
 async-trait.workspace = true
 async-watch.workspace = true
+buffer_diff.workspace = true
 chrono.workspace = true
 clap.workspace = true
 client.workspace = true

crates/eval/src/example.rs 🔗

@@ -1,6 +1,7 @@
 use std::{
     error::Error,
     fmt::{self, Debug},
+    path::Path,
     sync::{Arc, Mutex},
     time::Duration,
 };
@@ -12,6 +13,8 @@ use crate::{
 use agent::ThreadEvent;
 use anyhow::{Result, anyhow};
 use async_trait::async_trait;
+use buffer_diff::DiffHunkStatus;
+use collections::HashMap;
 use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
 use gpui::{AppContext, AsyncApp, Entity};
 use language_model::{LanguageModel, Role, StopReason};
@@ -234,9 +237,9 @@ impl ExampleContext {
                             let mut tool_metrics = tool_metrics.lock().unwrap();
                             if let Some(tool_result) = thread.tool_result(&tool_use_id) {
                                 let message = if tool_result.is_error {
-                                    format!("TOOL FAILED: {}", tool_use.name)
+                                    format!("✖︎ {}", tool_use.name)
                                 } else {
-                                    format!("TOOL FINISHED: {}", tool_use.name)
+                                    format!("✔︎ {}", tool_use.name)
                                 };
                                 println!("{log_prefix}{message}");
                                 tool_metrics
@@ -320,6 +323,36 @@ impl ExampleContext {
 
         Ok(response)
     }
+
+    pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
+        self.app
+            .read_entity(&self.agent_thread, |thread, cx| {
+                let action_log = thread.action_log().read(cx);
+                HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
+                    |(buffer, diff)| {
+                        let snapshot = buffer.read(cx).snapshot();
+
+                        let file = snapshot.file().unwrap();
+                        let diff = diff.read(cx);
+                        let base_text = diff.base_text().text();
+
+                        let hunks = diff
+                            .hunks(&snapshot, cx)
+                            .map(|hunk| FileEditHunk {
+                                base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
+                                text: snapshot
+                                    .text_for_range(hunk.range.clone())
+                                    .collect::<String>(),
+                                status: hunk.status(),
+                            })
+                            .collect();
+
+                        (file.path().clone(), FileEdits { hunks })
+                    },
+                ))
+            })
+            .unwrap()
+    }
 }
 
 #[derive(Debug)]
@@ -344,6 +377,10 @@ impl Response {
         });
         cx.assert_some(result, format!("called `{}`", tool_name))
     }
+
+    pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
+        self.messages.iter().flat_map(|msg| &msg.tool_use)
+    }
 }
 
 #[derive(Debug)]
@@ -355,17 +392,37 @@ pub struct Message {
 
 #[derive(Debug)]
 pub struct ToolUse {
-    name: String,
+    pub name: String,
     value: serde_json::Value,
 }
 
 impl ToolUse {
-    pub fn expect_input<Input>(&self, cx: &mut ExampleContext) -> Result<Input>
+    pub fn parse_input<Input>(&self) -> Result<Input>
     where
         Input: for<'de> serde::Deserialize<'de>,
     {
-        let result =
-            serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err));
-        cx.log_assertion(result, format!("valid `{}` input", &self.name))
+        serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
+    }
+}
+
+#[derive(Debug)]
+pub struct FileEdits {
+    hunks: Vec<FileEditHunk>,
+}
+
+#[derive(Debug)]
+struct FileEditHunk {
+    base_text: String,
+    text: String,
+    status: DiffHunkStatus,
+}
+
+impl FileEdits {
+    pub fn has_added_line(&self, line: &str) -> bool {
+        self.hunks.iter().any(|hunk| {
+            hunk.status == DiffHunkStatus::added_none()
+                && hunk.base_text.is_empty()
+                && hunk.text.contains(line)
+        })
     }
 }

crates/eval/src/examples/add_arg_to_trait_method.rs 🔗

@@ -0,0 +1,147 @@
+use std::{collections::HashSet, path::Path};
+
+use anyhow::Result;
+use assistant_tools::{CreateFileToolInput, EditFileToolInput, ReadFileToolInput};
+use async_trait::async_trait;
+
+use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer};
+
+pub struct AddArgToTraitMethod;
+
+#[async_trait(?Send)]
+impl Example for AddArgToTraitMethod {
+    fn meta(&self) -> ExampleMetadata {
+        ExampleMetadata {
+            name: "add_arg_to_trait_method".to_string(),
+            url: "https://github.com/zed-industries/zed.git".to_string(),
+            revision: "f69aeb6311dde3c0b8979c293d019d66498d54f2".to_string(),
+            language_server: Some(LanguageServer {
+                file_extension: "rs".to_string(),
+                allow_preexisting_diagnostics: false,
+            }),
+            max_assertions: None,
+        }
+    }
+
+    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
+        const FILENAME: &str = "assistant_tool.rs";
+        cx.push_user_message(format!(
+            r#"
+            Add a `window: Option<gpui::AnyWindowHandle>` argument to the `Tool::run` trait method in {FILENAME},
+            and update all the implementations of the trait and call sites accordingly.
+            "#
+        ));
+
+        let response = cx.run_to_end().await?;
+
+        // Reads files before it edits them
+
+        let mut read_files = HashSet::new();
+
+        for tool_use in response.tool_uses() {
+            match tool_use.name.as_str() {
+                "read_file" => {
+                    if let Ok(input) = tool_use.parse_input::<ReadFileToolInput>() {
+                        read_files.insert(input.path);
+                    }
+                }
+                "create_file" => {
+                    if let Ok(input) = tool_use.parse_input::<CreateFileToolInput>() {
+                        read_files.insert(input.path);
+                    }
+                }
+                "edit_file" => {
+                    if let Ok(input) = tool_use.parse_input::<EditFileToolInput>() {
+                        cx.assert(
+                            read_files.contains(input.path.to_str().unwrap()),
+                            format!(
+                                "Read before edit: {}",
+                                &input.path.file_stem().unwrap().to_str().unwrap()
+                            ),
+                        )
+                        .ok();
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        // Adds ignored argument to all but `batch_tool`
+
+        let add_ignored_window_paths = &[
+            "code_action_tool",
+            "code_symbols_tool",
+            "contents_tool",
+            "copy_path_tool",
+            "create_directory_tool",
+            "create_file_tool",
+            "delete_path_tool",
+            "diagnostics_tool",
+            "edit_file_tool",
+            "fetch_tool",
+            "grep_tool",
+            "list_directory_tool",
+            "move_path_tool",
+            "now_tool",
+            "open_tool",
+            "path_search_tool",
+            "read_file_tool",
+            "rename_tool",
+            "symbol_info_tool",
+            "terminal_tool",
+            "thinking_tool",
+            "web_search_tool",
+        ];
+
+        let edits = cx.edits();
+
+        for tool_name in add_ignored_window_paths {
+            let path_str = format!("crates/assistant_tools/src/{}.rs", tool_name);
+            let edits = edits.get(Path::new(&path_str));
+
+            let ignored = edits.map_or(false, |edits| {
+                edits.has_added_line("        _window: Option<gpui::AnyWindowHandle>,\n")
+            });
+            let uningored = edits.map_or(false, |edits| {
+                edits.has_added_line("        window: Option<gpui::AnyWindowHandle>,\n")
+            });
+
+            cx.assert(ignored || uningored, format!("Argument:   {}", tool_name))
+                .ok();
+
+            cx.assert(ignored, format!("`_` prefix: {}", tool_name))
+                .ok();
+        }
+
+        // Adds unignored argument to `batch_tool`
+
+        let batch_tool_edits = edits.get(Path::new("crates/assistant_tools/src/batch_tool.rs"));
+
+        cx.assert(
+            batch_tool_edits.map_or(false, |edits| {
+                edits.has_added_line("        window: Option<gpui::AnyWindowHandle>,\n")
+            }),
+            "Argument:   batch_tool",
+        )
+        .ok();
+
+        Ok(())
+    }
+
+    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
+        vec![
+            JudgeAssertion {
+                id: "batch tool passes window to each".to_string(),
+                description:
+                    "batch_tool is modified to pass a clone of the window to each tool it calls."
+                        .to_string(),
+            },
+            JudgeAssertion {
+                id: "tool tests updated".to_string(),
+                description:
+                    "tool tests are updated to pass the new `window` argument (`None` is ok)."
+                        .to_string(),
+            },
+        ]
+    }
+}

crates/eval/src/examples/file_search.rs 🔗

@@ -33,7 +33,7 @@ impl Example for FileSearchExample {
 
         let response = cx.run_turn().await?;
         let tool_use = response.expect_tool("path_search", cx)?;
-        let input = tool_use.expect_input::<PathSearchToolInput>(cx)?;
+        let input = tool_use.parse_input::<PathSearchToolInput>()?;
 
         let glob = input.glob;
         cx.assert(

crates/eval/src/examples/mod.rs 🔗

@@ -11,10 +11,14 @@ use util::serde::default_true;
 
 use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion};
 
+mod add_arg_to_trait_method;
 mod file_search;
 
 pub fn all(examples_dir: &Path) -> Vec<Rc<dyn Example>> {
-    let mut threads: Vec<Rc<dyn Example>> = vec![Rc::new(file_search::FileSearchExample)];
+    let mut threads: Vec<Rc<dyn Example>> = vec![
+        Rc::new(file_search::FileSearchExample),
+        Rc::new(add_arg_to_trait_method::AddArgToTraitMethod),
+    ];
 
     for example_path in list_declarative_examples(examples_dir).unwrap() {
         threads.push(Rc::new(DeclarativeExample::load(&example_path).unwrap()));