Add initial implementation of evaluating changes generated by the assistant (#26799)

Michael Sloan , Richard Feldman , and Thomas created

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Thomas <thomas@zed.dev>

Change summary

Cargo.lock                                      |  35 ++
Cargo.toml                                      |   2 
crates/assistant2/src/active_thread.rs          |   1 
crates/assistant2/src/assistant.rs              |   3 
crates/assistant2/src/thread.rs                 |  58 ++-
crates/assistant2/src/tool_use.rs               |   7 
crates/assistant_eval/Cargo.toml                |  44 +++
crates/assistant_eval/LICENSE-GPL               |   1 
crates/assistant_eval/README.md                 |  77 +++++
crates/assistant_eval/build.rs                  |  61 ++++
crates/assistant_eval/src/eval.rs               | 252 +++++++++++++++++++
crates/assistant_eval/src/headless_assistant.rs | 241 ++++++++++++++++++
crates/assistant_eval/src/judge.rs              | 121 +++++++++
crates/assistant_eval/src/main.rs               | 234 +++++++++++++++++
14 files changed, 1,113 insertions(+), 24 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -566,6 +566,41 @@ dependencies = [
  "workspace",
 ]
 
+[[package]]
+name = "assistant_eval"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "assistant2",
+ "assistant_tool",
+ "assistant_tools",
+ "clap",
+ "client",
+ "collections",
+ "context_server",
+ "env_logger 0.11.7",
+ "fs",
+ "futures 0.3.31",
+ "gpui",
+ "gpui_tokio",
+ "itertools 0.14.0",
+ "language",
+ "language_model",
+ "language_models",
+ "node_runtime",
+ "project",
+ "prompt_store",
+ "regex",
+ "release_channel",
+ "reqwest_client",
+ "serde",
+ "serde_json",
+ "serde_json_lenient",
+ "settings",
+ "smol",
+ "util",
+]
+
 [[package]]
 name = "assistant_settings"
 version = "0.1.0"

Cargo.toml 🔗

@@ -8,6 +8,7 @@ members = [
     "crates/assistant",
     "crates/assistant2",
     "crates/assistant_context_editor",
+    "crates/assistant_eval",
     "crates/assistant_settings",
     "crates/assistant_slash_command",
     "crates/assistant_slash_commands",
@@ -206,6 +207,7 @@ assets = { path = "crates/assets" }
 assistant = { path = "crates/assistant" }
 assistant2 = { path = "crates/assistant2" }
 assistant_context_editor = { path = "crates/assistant_context_editor" }
+assistant_eval = { path = "crates/assistant_eval" }
 assistant_settings = { path = "crates/assistant_settings" }
 assistant_slash_command = { path = "crates/assistant_slash_command" }
 assistant_slash_commands = { path = "crates/assistant_slash_commands" }

crates/assistant2/src/active_thread.rs 🔗

@@ -298,6 +298,7 @@ impl ActiveThread {
             ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
                 self.save_thread(cx);
             }
+            ThreadEvent::DoneStreaming => {}
             ThreadEvent::StreamedAssistantText(message_id, text) => {
                 if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
                     markdown.update(cx, |markdown, cx| {

crates/assistant2/src/assistant.rs 🔗

@@ -31,8 +31,11 @@ use gpui::{actions, App};
 use prompt_store::PromptBuilder;
 use settings::Settings as _;
 
+pub use crate::active_thread::ActiveThread;
 pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
 pub use crate::inline_assistant::InlineAssistant;
+pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent};
+pub use crate::thread_store::ThreadStore;
 
 actions!(
     assistant2,

crates/assistant2/src/thread.rs 🔗

@@ -284,6 +284,10 @@ impl Thread {
         self.tool_use.tool_results_for_message(id)
     }
 
+    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
+        self.tool_use.tool_result(id)
+    }
+
     pub fn scripting_tool_results_for_message(
         &self,
         id: MessageId,
@@ -652,32 +656,37 @@ impl Thread {
             let result = stream_completion.await;
 
             thread
-                .update(&mut cx, |thread, cx| match result.as_ref() {
-                    Ok(stop_reason) => match stop_reason {
-                        StopReason::ToolUse => {
-                            cx.emit(ThreadEvent::UsePendingTools);
-                        }
-                        StopReason::EndTurn => {}
-                        StopReason::MaxTokens => {}
-                    },
-                    Err(error) => {
-                        if error.is::<PaymentRequiredError>() {
-                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
-                        } else if error.is::<MaxMonthlySpendReachedError>() {
-                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
-                        } else {
-                            let error_message = error
-                                .chain()
-                                .map(|err| err.to_string())
-                                .collect::<Vec<_>>()
-                                .join("\n");
-                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
-                                SharedString::from(error_message.clone()),
-                            )));
-                        }
+                .update(&mut cx, |thread, cx| {
+                    match result.as_ref() {
+                        Ok(stop_reason) => match stop_reason {
+                            StopReason::ToolUse => {
+                                cx.emit(ThreadEvent::UsePendingTools);
+                            }
+                            StopReason::EndTurn => {}
+                            StopReason::MaxTokens => {}
+                        },
+                        Err(error) => {
+                            if error.is::<PaymentRequiredError>() {
+                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
+                            } else if error.is::<MaxMonthlySpendReachedError>() {
+                                cx.emit(ThreadEvent::ShowError(
+                                    ThreadError::MaxMonthlySpendReached,
+                                ));
+                            } else {
+                                let error_message = error
+                                    .chain()
+                                    .map(|err| err.to_string())
+                                    .collect::<Vec<_>>()
+                                    .join("\n");
+                                cx.emit(ThreadEvent::ShowError(ThreadError::Message(
+                                    SharedString::from(error_message.clone()),
+                                )));
+                            }
 
-                        thread.cancel_last_completion();
+                            thread.cancel_last_completion();
+                        }
                     }
+                    cx.emit(ThreadEvent::DoneStreaming);
                 })
                 .ok();
         });
@@ -1094,6 +1103,7 @@ pub enum ThreadEvent {
     ShowError(ThreadError),
     StreamedCompletion,
     StreamedAssistantText(MessageId, String),
+    DoneStreaming,
     MessageAdded(MessageId),
     MessageEdited(MessageId),
     MessageDeleted(MessageId),

crates/assistant2/src/tool_use.rs 🔗

@@ -182,6 +182,13 @@ impl ToolUseState {
             .map_or(false, |results| !results.is_empty())
     }
 
+    pub fn tool_result(
+        &self,
+        tool_use_id: &LanguageModelToolUseId,
+    ) -> Option<&LanguageModelToolResult> {
+        self.tool_results.get(tool_use_id)
+    }
+
     pub fn request_tool_use(
         &mut self,
         assistant_message_id: MessageId,

crates/assistant_eval/Cargo.toml 🔗

@@ -0,0 +1,44 @@
+[package]
+name = "assistant_eval"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[[bin]]
+name = "assistant_eval"
+path = "src/main.rs"
+
+[dependencies]
+anyhow.workspace = true
+assistant2.workspace = true
+assistant_tool.workspace = true
+assistant_tools.workspace = true
+clap.workspace = true
+client.workspace = true
+collections.workspace = true
+context_server.workspace = true
+env_logger.workspace = true
+fs.workspace = true
+futures.workspace = true
+gpui.workspace = true
+gpui_tokio.workspace = true
+itertools.workspace = true
+language.workspace = true
+language_model.workspace = true
+language_models.workspace = true
+node_runtime.workspace = true
+project.workspace = true
+prompt_store.workspace = true
+regex.workspace = true
+release_channel.workspace = true
+reqwest_client.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+serde_json_lenient.workspace = true
+settings.workspace = true
+smol.workspace = true
+util.workspace = true

crates/assistant_eval/README.md 🔗

@@ -0,0 +1,77 @@
+# Tool Evals
+
+A framework for evaluating and benchmarking AI assistant performance in the Zed editor.
+
+## Overview
+
+Tool Evals provides a headless environment for running assistants evaluations on code repositories. It automates the process of:
+
+1. Cloning and setting up test repositories
+2. Sending prompts to language models
+3. Allowing the assistant to use tools to modify code
+4. Collecting metrics on performance
+5. Evaluating results against known good solutions
+
+## How It Works
+
+The system consists of several key components:
+
+- **Eval**: Loads test cases from the evaluation_data directory, clones repos, and executes evaluations
+- **HeadlessAssistant**: Provides a headless environment for running the AI assistant
+- **Judge**: Compares AI-generated diffs with reference solutions and scores their functional similarity
+
+The evaluation flow:
+1. An evaluation is loaded from the evaluation_data directory
+2. The target repository is cloned and checked out at a specific commit
+3. A HeadlessAssistant instance is created with the specified language model
+4. The user prompt is sent to the assistant
+5. The assistant responds and uses tools to modify code
+6. Upon completion, a diff is generated from the changes
+7. Results are saved including the diff, assistant's response, and performance metrics
+8. If a reference solution exists, a Judge evaluates the similarity of the solution
+
+## Setup Requirements
+
+### Prerequisites
+
+- Rust and Cargo
+- Git
+- Network access to clone repositories
+- Appropriate API keys for language models and git services (Anthropic, GitHub, etc.)
+
+### Environment Variables
+
+Ensure you have the required API keys set, either from a dev run of Zed or via these environment variables:
+- `ZED_ANTHROPIC_API_KEY` for Claude models
+- `ZED_OPENAI_API_KEY` for OpenAI models
+- `ZED_GITHUB_API_KEY` for GitHub API (or similar)
+
+## Usage
+
+### Running a Single Evaluation
+
+To run a specific evaluation:
+
+```bash
+cargo run -p assistant_eval -- bubbletea-add-set-window-title
+```
+
+The arguments are regex patterns for the evaluation names to run, so to run all evaluations that contain `bubbletea`, run:
+
+```bash
+cargo run -p assistant_eval -- bubbletea
+```
+
+To run all evaluations:
+
+```bash
+cargo run -p assistant_eval -- --all
+```
+
+## Evaluation Data Structure
+
+Each evaluation should be placed in the `evaluation_data` directory with the following structure:
+
+* `prompt.txt`: The user's prompt.
+* `original.diff`: The `git diff` of the change anticipated for this prompt.
+* `setup.json`: Information about the repo used for the evaluation.

crates/assistant_eval/build.rs 🔗

@@ -0,0 +1,61 @@
+// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
+
+use std::process::Command;
+
+fn main() {
+    if cfg!(target_os = "macos") {
+        println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
+
+        println!("cargo:rerun-if-env-changed=ZED_BUNDLE");
+        if std::env::var("ZED_BUNDLE").ok().as_deref() == Some("true") {
+            // Find WebRTC.framework in the Frameworks folder when running as part of an application bundle.
+            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../Frameworks");
+        } else {
+            // Find WebRTC.framework as a sibling of the executable when running outside of an application bundle.
+            println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
+        }
+
+        // Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
+        println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
+
+        // Seems to be required to enable Swift concurrency
+        println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
+
+        // Register exported Objective-C selectors, protocols, etc
+        println!("cargo:rustc-link-arg=-Wl,-ObjC");
+    }
+
+    // Populate git sha environment variable if git is available
+    println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
+    println!(
+        "cargo:rustc-env=TARGET={}",
+        std::env::var("TARGET").unwrap()
+    );
+    if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
+        if output.status.success() {
+            let git_sha = String::from_utf8_lossy(&output.stdout);
+            let git_sha = git_sha.trim();
+
+            println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
+
+            if let Ok(build_profile) = std::env::var("PROFILE") {
+                if build_profile == "release" {
+                    // This is currently the best way to make `cargo build ...`'s build script
+                    // to print something to stdout without extra verbosity.
+                    println!(
+                        "cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
+                    );
+                }
+            }
+        }
+    }
+
+    #[cfg(target_os = "windows")]
+    {
+        #[cfg(target_env = "msvc")]
+        {
+            // todo(windows): This is to avoid stack overflow. Remove it when solved.
+            println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
+        }
+    }
+}

crates/assistant_eval/src/eval.rs 🔗

@@ -0,0 +1,252 @@
+use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
+use anyhow::anyhow;
+use assistant2::RequestKind;
+use collections::HashMap;
+use gpui::{App, Task};
+use language_model::{LanguageModel, TokenUsage};
+use serde::{Deserialize, Serialize};
+use std::{
+    fs,
+    io::Write,
+    path::{Path, PathBuf},
+    sync::Arc,
+    time::Duration,
+};
+use util::command::new_smol_command;
+
+pub struct Eval {
+    pub name: String,
+    pub path: PathBuf,
+    pub repo_path: PathBuf,
+    pub eval_setup: EvalSetup,
+    pub user_prompt: String,
+}
+
+#[derive(Debug, Serialize)]
+pub struct EvalOutput {
+    pub diff: String,
+    pub last_message: String,
+    pub elapsed_time: Duration,
+    pub assistant_response_count: usize,
+    pub tool_use_counts: HashMap<Arc<str>, u32>,
+    pub token_usage: TokenUsage,
+}
+
+#[derive(Deserialize)]
+pub struct EvalSetup {
+    pub url: String,
+    pub base_sha: String,
+}
+
+impl Eval {
+    /// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
+    /// if necessary.
+    pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
+        let prompt_path = path.join("prompt.txt");
+        let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
+        let setup_path = path.join("setup.json");
+        let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
+        let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
+        let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
+        Ok(Eval {
+            name,
+            path,
+            repo_path,
+            eval_setup,
+            user_prompt,
+        })
+    }
+
+    pub fn run(
+        self,
+        app_state: Arc<HeadlessAppState>,
+        model: Arc<dyn LanguageModel>,
+        cx: &mut App,
+    ) -> Task<anyhow::Result<EvalOutput>> {
+        cx.spawn(move |mut cx| async move {
+            checkout_repo(&self.eval_setup, &self.repo_path).await?;
+
+            let (assistant, done_rx) =
+                cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
+
+            let _worktree = assistant
+                .update(&mut cx, |assistant, cx| {
+                    assistant.project.update(cx, |project, cx| {
+                        project.create_worktree(&self.repo_path, true, cx)
+                    })
+                })?
+                .await?;
+
+            let start_time = std::time::SystemTime::now();
+
+            assistant.update(&mut cx, |assistant, cx| {
+                assistant.thread.update(cx, |thread, cx| {
+                    let context = vec![];
+                    thread.insert_user_message(self.user_prompt.clone(), context, cx);
+                    thread.send_to_model(model, RequestKind::Chat, cx);
+                });
+            })?;
+
+            done_rx.recv().await??;
+
+            let elapsed_time = start_time.elapsed()?;
+
+            let diff = query_git(&self.repo_path, vec!["diff"]).await?;
+
+            assistant.update(&mut cx, |assistant, cx| {
+                let thread = assistant.thread.read(cx);
+                let last_message = thread.messages().last().unwrap();
+                if last_message.role != language_model::Role::Assistant {
+                    return Err(anyhow!("Last message is not from assistant"));
+                }
+                let assistant_response_count = thread
+                    .messages()
+                    .filter(|message| message.role == language_model::Role::Assistant)
+                    .count();
+                Ok(EvalOutput {
+                    diff,
+                    last_message: last_message.text.clone(),
+                    elapsed_time,
+                    assistant_response_count,
+                    tool_use_counts: assistant.tool_use_counts.clone(),
+                    token_usage: thread.cumulative_token_usage(),
+                })
+            })?
+        })
+    }
+}
+
+impl EvalOutput {
+    // Method to save the output to a directory
+    pub fn save_to_directory(
+        &self,
+        output_dir: &Path,
+        eval_output_value: String,
+    ) -> anyhow::Result<()> {
+        // Create the output directory if it doesn't exist
+        fs::create_dir_all(&output_dir)?;
+
+        // Save the diff to a file
+        let diff_path = output_dir.join("diff.patch");
+        let mut diff_file = fs::File::create(&diff_path)?;
+        diff_file.write_all(self.diff.as_bytes())?;
+
+        // Save the last message to a file
+        let message_path = output_dir.join("assistant_response.txt");
+        let mut message_file = fs::File::create(&message_path)?;
+        message_file.write_all(self.last_message.as_bytes())?;
+
+        // Current metrics for this run
+        let current_metrics = serde_json::json!({
+            "elapsed_time_ms": self.elapsed_time.as_millis(),
+            "assistant_response_count": self.assistant_response_count,
+            "tool_use_counts": self.tool_use_counts,
+            "token_usage": self.token_usage,
+            "eval_output_value": eval_output_value,
+        });
+
+        // Get current timestamp in milliseconds
+        let timestamp = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)?
+            .as_millis()
+            .to_string();
+
+        // Path to metrics file
+        let metrics_path = output_dir.join("metrics.json");
+
+        // Load existing metrics if the file exists, or create a new object
+        let mut historical_metrics = if metrics_path.exists() {
+            let metrics_content = fs::read_to_string(&metrics_path)?;
+            serde_json::from_str::<serde_json::Value>(&metrics_content)
+                .unwrap_or_else(|_| serde_json::json!({}))
+        } else {
+            serde_json::json!({})
+        };
+
+        // Add new run with timestamp as key
+        if let serde_json::Value::Object(ref mut map) = historical_metrics {
+            map.insert(timestamp, current_metrics);
+        }
+
+        // Write updated metrics back to file
+        let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
+        let mut metrics_file = fs::File::create(&metrics_path)?;
+        metrics_file.write_all(metrics_json.as_bytes())?;
+
+        Ok(())
+    }
+}
+
+fn repo_dir_name(url: &str) -> String {
+    url.trim_start_matches("https://")
+        .replace(|c: char| !c.is_alphanumeric(), "_")
+}
+
+async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
+    if !repo_path.exists() {
+        smol::unblock({
+            let repo_path = repo_path.to_path_buf();
+            || std::fs::create_dir_all(repo_path)
+        })
+        .await?;
+        run_git(repo_path, vec!["init"]).await?;
+        run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
+    } else {
+        let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
+        if actual_origin != eval_setup.url {
+            return Err(anyhow!(
+                "remote origin {} does not match expected origin {}",
+                actual_origin,
+                eval_setup.url
+            ));
+        }
+
+        // TODO: consider including "-x" to remove ignored files. The downside of this is that it will
+        // also remove build artifacts, and so prevent incremental reuse there.
+        run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
+        run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
+    }
+
+    run_git(
+        repo_path,
+        vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
+    )
+    .await?;
+    run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
+
+    Ok(())
+}
+
+async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
+    let exit_status = new_smol_command("git")
+        .current_dir(repo_path)
+        .args(args.clone())
+        .status()
+        .await?;
+    if exit_status.success() {
+        Ok(())
+    } else {
+        Err(anyhow!(
+            "`git {}` failed with {}",
+            args.join(" "),
+            exit_status,
+        ))
+    }
+}
+
+async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
+    let output = new_smol_command("git")
+        .current_dir(repo_path)
+        .args(args.clone())
+        .output()
+        .await?;
+    if output.status.success() {
+        Ok(String::from_utf8(output.stdout)?.trim().to_string())
+    } else {
+        Err(anyhow!(
+            "`git {}` failed with {}",
+            args.join(" "),
+            output.status
+        ))
+    }
+}

crates/assistant_eval/src/headless_assistant.rs 🔗

@@ -0,0 +1,241 @@
+use anyhow::anyhow;
+use assistant2::{Thread, ThreadEvent, ThreadStore};
+use assistant_tool::ToolWorkingSet;
+use client::{Client, UserStore};
+use collections::HashMap;
+use futures::StreamExt;
+use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task};
+use language::LanguageRegistry;
+use language_model::{
+    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
+    LanguageModelRequest,
+};
+use node_runtime::NodeRuntime;
+use project::{Project, RealFs};
+use prompt_store::PromptBuilder;
+use settings::SettingsStore;
+use smol::channel;
+use std::sync::Arc;
+
+/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
+pub struct HeadlessAppState {
+    pub languages: Arc<LanguageRegistry>,
+    pub client: Arc<Client>,
+    pub user_store: Entity<UserStore>,
+    pub fs: Arc<dyn fs::Fs>,
+    pub node_runtime: NodeRuntime,
+
+    // Additional fields not present in `workspace::AppState`.
+    pub prompt_builder: Arc<PromptBuilder>,
+}
+
+pub struct HeadlessAssistant {
+    pub thread: Entity<Thread>,
+    pub project: Entity<Project>,
+    #[allow(dead_code)]
+    pub thread_store: Entity<ThreadStore>,
+    pub tool_use_counts: HashMap<Arc<str>, u32>,
+    pub done_tx: channel::Sender<anyhow::Result<()>>,
+    _subscription: Subscription,
+}
+
+impl HeadlessAssistant {
+    pub fn new(
+        app_state: Arc<HeadlessAppState>,
+        cx: &mut App,
+    ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
+        let env = None;
+        let project = Project::local(
+            app_state.client.clone(),
+            app_state.node_runtime.clone(),
+            app_state.user_store.clone(),
+            app_state.languages.clone(),
+            app_state.fs.clone(),
+            env,
+            cx,
+        );
+
+        let tools = Arc::new(ToolWorkingSet::default());
+        let thread_store =
+            ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
+
+        let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
+
+        let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
+
+        let headless_thread = cx.new(move |cx| Self {
+            _subscription: cx.subscribe(&thread, Self::handle_thread_event),
+            thread,
+            project,
+            thread_store,
+            tool_use_counts: HashMap::default(),
+            done_tx,
+        });
+
+        Ok((headless_thread, done_rx))
+    }
+
+    fn handle_thread_event(
+        &mut self,
+        thread: Entity<Thread>,
+        event: &ThreadEvent,
+        cx: &mut Context<Self>,
+    ) {
+        match event {
+            ThreadEvent::ShowError(err) => self
+                .done_tx
+                .send_blocking(Err(anyhow!("{:?}", err)))
+                .unwrap(),
+            ThreadEvent::DoneStreaming => {
+                let thread = thread.read(cx);
+                if let Some(message) = thread.messages().last() {
+                    println!("Message: {}", message.text,);
+                }
+                if thread.all_tools_finished() {
+                    self.done_tx.send_blocking(Ok(())).unwrap()
+                }
+            }
+            ThreadEvent::UsePendingTools => {
+                thread.update(cx, |thread, cx| {
+                    thread.use_pending_tools(cx);
+                });
+            }
+            ThreadEvent::ToolFinished {
+                tool_use_id,
+                pending_tool_use,
+            } => {
+                if let Some(pending_tool_use) = pending_tool_use {
+                    println!(
+                        "Used tool {} with input: {}",
+                        pending_tool_use.name, pending_tool_use.input
+                    );
+                    *self
+                        .tool_use_counts
+                        .entry(pending_tool_use.name.clone())
+                        .or_insert(0) += 1;
+                }
+                if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
+                    println!("Tool result: {:?}", tool_result);
+                }
+                if thread.read(cx).all_tools_finished() {
+                    let model_registry = LanguageModelRegistry::read_global(cx);
+                    if let Some(model) = model_registry.active_model() {
+                        thread.update(cx, |thread, cx| {
+                            // Currently evals do not support specifying context.
+                            let updated_context = vec![];
+                            thread.send_tool_results_to_model(model, updated_context, cx);
+                        });
+                    }
+                }
+            }
+            ThreadEvent::StreamedCompletion
+            | ThreadEvent::SummaryChanged
+            | ThreadEvent::StreamedAssistantText(_, _)
+            | ThreadEvent::MessageAdded(_)
+            | ThreadEvent::MessageEdited(_)
+            | ThreadEvent::MessageDeleted(_) => {}
+        }
+    }
+}
+
+pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
+    release_channel::init(SemanticVersion::default(), cx);
+    gpui_tokio::init(cx);
+
+    let mut settings_store = SettingsStore::new(cx);
+    settings_store
+        .set_default_settings(settings::default_settings().as_ref(), cx)
+        .unwrap();
+    cx.set_global(settings_store);
+    client::init_settings(cx);
+    Project::init_settings(cx);
+
+    let client = Client::production(cx);
+    cx.set_http_client(client.http_client().clone());
+
+    let git_binary_path = None;
+    let fs = Arc::new(RealFs::new(git_binary_path));
+
+    let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
+
+    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+
+    language::init(cx);
+    language_model::init(client.clone(), cx);
+    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
+    assistant_tools::init(cx);
+    context_server::init(cx);
+    let stdout_is_a_pty = false;
+    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
+    assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
+
+    Arc::new(HeadlessAppState {
+        languages,
+        client,
+        user_store,
+        fs,
+        node_runtime: NodeRuntime::unavailable(),
+        prompt_builder,
+    })
+}
+
+pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
+    let model_registry = LanguageModelRegistry::read_global(cx);
+    let model = model_registry
+        .available_models(cx)
+        .find(|model| model.id().0 == model_name);
+
+    let Some(model) = model else {
+        return Err(anyhow!(
+            "No language model named {} was available. Available models: {}",
+            model_name,
+            model_registry
+                .available_models(cx)
+                .map(|model| model.id().0.clone())
+                .collect::<Vec<_>>()
+                .join(", ")
+        ));
+    };
+
+    Ok(model)
+}
+
+pub fn authenticate_model_provider(
+    provider_id: LanguageModelProviderId,
+    cx: &mut App,
+) -> Task<std::result::Result<(), AuthenticateError>> {
+    let model_registry = LanguageModelRegistry::read_global(cx);
+    let model_provider = model_registry.provider(&provider_id).unwrap();
+    model_provider.authenticate(cx)
+}
+
+pub async fn send_language_model_request(
+    model: Arc<dyn LanguageModel>,
+    request: LanguageModelRequest,
+    cx: AsyncApp,
+) -> anyhow::Result<String> {
+    match model.stream_completion_text(request, &cx).await {
+        Ok(mut stream) => {
+            let mut full_response = String::new();
+
+            // Process the response stream
+            while let Some(chunk_result) = stream.stream.next().await {
+                match chunk_result {
+                    Ok(chunk_str) => {
+                        full_response.push_str(&chunk_str);
+                    }
+                    Err(err) => {
+                        return Err(anyhow!(
+                            "Error receiving response from language model: {err}"
+                        ));
+                    }
+                }
+            }
+
+            Ok(full_response)
+        }
+        Err(err) => Err(anyhow!(
+            "Failed to get response from language model. Error was: {err}"
+        )),
+    }
+}

crates/assistant_eval/src/judge.rs 🔗

@@ -0,0 +1,121 @@
+use crate::eval::EvalOutput;
+use crate::headless_assistant::send_language_model_request;
+use anyhow::anyhow;
+use gpui::{App, Task};
+use language_model::{
+    LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
+};
+use std::{path::Path, sync::Arc};
+
+pub struct Judge {
+    pub original_diff: Option<String>,
+    #[allow(dead_code)]
+    pub original_message: Option<String>,
+    pub model: Arc<dyn LanguageModel>,
+}
+
+impl Judge {
+    pub async fn load(eval_path: &Path, model: Arc<dyn LanguageModel>) -> anyhow::Result<Judge> {
+        let original_diff_path = eval_path.join("original.diff");
+        let original_diff = smol::unblock(move || {
+            if std::fs::exists(&original_diff_path)? {
+                anyhow::Ok(Some(std::fs::read_to_string(&original_diff_path)?))
+            } else {
+                anyhow::Ok(None)
+            }
+        });
+
+        let original_message_path = eval_path.join("original_message.txt");
+        let original_message = smol::unblock(move || {
+            if std::fs::exists(&original_message_path)? {
+                anyhow::Ok(Some(std::fs::read_to_string(&original_message_path)?))
+            } else {
+                anyhow::Ok(None)
+            }
+        });
+
+        Ok(Self {
+            original_diff: original_diff.await?,
+            original_message: original_message.await?,
+            model,
+        })
+    }
+
+    pub fn run(&self, eval_output: &EvalOutput, cx: &mut App) -> Task<anyhow::Result<String>> {
+        let Some(original_diff) = self.original_diff.as_ref() else {
+            return Task::ready(Err(anyhow!("No original.diff found")));
+        };
+
+        // TODO: check for empty diff?
+        let prompt = diff_comparison_prompt(&original_diff, &eval_output.diff);
+
+        let request = LanguageModelRequest {
+            messages: vec![LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec![MessageContent::Text(prompt)],
+                cache: false,
+            }],
+            temperature: Some(0.0),
+            tools: Vec::new(),
+            stop: Vec::new(),
+        };
+
+        let model = self.model.clone();
+        cx.spawn(move |cx| send_language_model_request(model, request, cx))
+    }
+}
+
+pub fn diff_comparison_prompt(original_diff: &str, new_diff: &str) -> String {
+    format!(
+        r#"# Git Diff Similarity Evaluation Template
+
+## Instructions
+
+Compare the two diffs and score them between 0.0 and 1.0 based on their functional similarity.
+- 1.0 = Perfect functional match (achieves identical results)
+- 0.0 = No functional similarity whatsoever
+
+## Evaluation Criteria
+
+Please consider the following aspects in order of importance:
+
+1. **Functional Equivalence (60%)**
+   - Do both diffs achieve the same end result?
+   - Are the changes functionally equivalent despite possibly using different approaches?
+   - Do the modifications address the same issues or implement the same features?
+
+2. **Logical Structure (20%)**
+   - Are the logical flows similar?
+   - Do the modifications affect the same code paths?
+   - Are control structures (if/else, loops, etc.) modified in similar ways?
+
+3. **Code Content (15%)**
+   - Are similar lines added/removed?
+   - Are the same variables, functions, or methods being modified?
+   - Are the same APIs or libraries being used?
+
+4. **File Layout (5%)**
+   - Are the same files being modified?
+   - Are changes occurring in similar locations within files?
+
+## Input
+
+Original Diff:
+```git
+{}
+```
+
+New Diff:
+```git
+{}
+```
+
+## Output Format
+
+THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0.
+
+Example output:
+0.85"#,
+        original_diff, new_diff
+    )
+}

crates/assistant_eval/src/main.rs 🔗

@@ -0,0 +1,234 @@
+mod eval;
+mod headless_assistant;
+mod judge;
+
+use clap::Parser;
+use eval::{Eval, EvalOutput};
+use futures::{stream, StreamExt};
+use gpui::{Application, AsyncApp};
+use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState};
+use itertools::Itertools;
+use judge::Judge;
+use language_model::{LanguageModel, LanguageModelRegistry};
+use regex::Regex;
+use reqwest_client::ReqwestClient;
+use std::{cmp, path::PathBuf, sync::Arc};
+
+#[derive(Parser, Debug)]
+#[command(
+    name = "assistant_eval",
+    disable_version_flag = true,
+    before_help = "Tool eval runner"
+)]
+struct Args {
+    /// Regexes to match the names of evals to run.
+    eval_name_regexes: Vec<String>,
+    /// Runs all evals in `evaluation_data`, causes the regex to be ignored.
+    #[arg(long)]
+    all: bool,
+    /// Name of the model (default: "claude-3-7-sonnet-latest")
+    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
+    model_name: String,
+    /// Name of the editor model (default: value of `--model_name`).
+    #[arg(long)]
+    editor_model_name: Option<String>,
+    /// Name of the judge model (default: value of `--model_name`).
+    #[arg(long)]
+    judge_model_name: Option<String>,
+    /// Number of evaluations to run concurrently (default: 10)
+    #[arg(short, long, default_value = "10")]
+    concurrency: usize,
+}
+
+fn main() {
+    env_logger::init();
+    let args = Args::parse();
+    let http_client = Arc::new(ReqwestClient::new());
+    let app = Application::headless().with_http_client(http_client.clone());
+
+    let crate_dir = PathBuf::from("../zed-agent-bench");
+    let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
+    let repos_dir = crate_dir.join("repos").canonicalize().unwrap();
+
+    let all_evals = std::fs::read_dir(&evaluation_data_dir)
+        .unwrap()
+        .map(|path| path.unwrap().file_name().to_string_lossy().to_string())
+        .collect::<Vec<_>>();
+
+    let evals_to_run = if args.all {
+        all_evals
+    } else {
+        args.eval_name_regexes
+            .into_iter()
+            .map(|regex_string| Regex::new(&regex_string).unwrap())
+            .flat_map(|regex| {
+                all_evals
+                    .iter()
+                    .filter(|eval_name| regex.is_match(eval_name))
+                    .cloned()
+                    .collect::<Vec<_>>()
+            })
+            .collect::<Vec<_>>()
+    };
+
+    if evals_to_run.is_empty() {
+        panic!("Names of evals to run must be provided or `--all` specified");
+    }
+
+    println!("Will run the following evals: {evals_to_run:?}");
+    println!("Running up to {} evals concurrently", args.concurrency);
+
+    let editor_model_name = if let Some(model_name) = args.editor_model_name {
+        model_name
+    } else {
+        args.model_name.clone()
+    };
+
+    let judge_model_name = if let Some(model_name) = args.judge_model_name {
+        model_name
+    } else {
+        args.model_name.clone()
+    };
+
+    app.run(move |cx| {
+        let app_state = headless_assistant::init(cx);
+
+        let model = find_model(&args.model_name, cx).unwrap();
+        let editor_model = find_model(&editor_model_name, cx).unwrap();
+        let judge_model = find_model(&judge_model_name, cx).unwrap();
+
+        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+            registry.set_active_model(Some(model.clone()), cx);
+            registry.set_editor_model(Some(editor_model.clone()), cx);
+        });
+
+        let model_provider_id = model.provider_id();
+        let editor_model_provider_id = editor_model.provider_id();
+        let judge_model_provider_id = judge_model.provider_id();
+
+        cx.spawn(move |cx| async move {
+            // Authenticate all model providers first
+            cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
+                .unwrap()
+                .await
+                .unwrap();
+            cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
+                .unwrap()
+                .await
+                .unwrap();
+            cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
+                .unwrap()
+                .await
+                .unwrap();
+
+            let loaded_evals = stream::iter(evals_to_run)
+                .map(|eval_name| {
+                    let eval_path = evaluation_data_dir.join(&eval_name);
+                    let repos_dir = repos_dir.clone();
+                    async move {
+                        match Eval::load(eval_name.clone(), eval_path, &repos_dir).await {
+                            Ok(eval) => Some(eval),
+                            Err(err) => {
+                                // TODO: Persist errors / surface errors at the end.
+                                println!("Error loading {eval_name}: {err}");
+                                None
+                            }
+                        }
+                    }
+                })
+                .buffer_unordered(args.concurrency)
+                .collect::<Vec<_>>()
+                .await
+                .into_iter()
+                .flatten()
+                .collect::<Vec<_>>();
+
+            // The evals need to be loaded and grouped by URL before concurrently running, since
+            // evals that use the same remote URL will use the same working directory.
+            let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
+                .into_iter()
+                .map(|eval| (eval.eval_setup.url.clone(), eval))
+                .into_group_map()
+                .into_values()
+                .collect::<Vec<_>>();
+
+            // Sort groups in descending order, so that bigger groups start first.
+            evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
+
+            let results = stream::iter(evals_grouped_by_url)
+                .map(|evals| {
+                    let model = model.clone();
+                    let judge_model = judge_model.clone();
+                    let app_state = app_state.clone();
+                    let cx = cx.clone();
+
+                    async move {
+                        let mut results = Vec::new();
+                        for eval in evals {
+                            let name = eval.name.clone();
+                            println!("Starting eval named {}", name);
+                            let result = run_eval(
+                                eval,
+                                model.clone(),
+                                judge_model.clone(),
+                                app_state.clone(),
+                                cx.clone(),
+                            )
+                            .await;
+                            results.push((name, result));
+                        }
+                        results
+                    }
+                })
+                .buffer_unordered(args.concurrency)
+                .collect::<Vec<_>>()
+                .await
+                .into_iter()
+                .flatten()
+                .collect::<Vec<_>>();
+
+            // Process results in order of completion
+            for (eval_name, result) in results {
+                match result {
+                    Ok((eval_output, judge_output)) => {
+                        println!("Generated diff for {eval_name}:\n");
+                        println!("{}\n", eval_output.diff);
+                        println!("Last message for {eval_name}:\n");
+                        println!("{}\n", eval_output.last_message);
+                        println!("Elapsed time: {:?}", eval_output.elapsed_time);
+                        println!(
+                            "Assistant response count: {}",
+                            eval_output.assistant_response_count
+                        );
+                        println!("Tool use counts: {:?}", eval_output.tool_use_counts);
+                        println!("Judge output for {eval_name}: {judge_output}");
+                    }
+                    Err(err) => {
+                        // TODO: Persist errors / surface errors at the end.
+                        println!("Error running {eval_name}: {err}");
+                    }
+                }
+            }
+
+            cx.update(|cx| cx.quit()).unwrap();
+        })
+        .detach();
+    });
+
+    println!("Done running evals");
+}
+
+async fn run_eval(
+    eval: Eval,
+    model: Arc<dyn LanguageModel>,
+    judge_model: Arc<dyn LanguageModel>,
+    app_state: Arc<HeadlessAppState>,
+    cx: AsyncApp,
+) -> anyhow::Result<(EvalOutput, String)> {
+    let path = eval.path.clone();
+    let judge = Judge::load(&path, judge_model).await?;
+    let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
+    let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
+    eval_output.save_to_directory(&path, judge_output.to_string())?;
+    Ok((eval_output, judge_output))
+}