Detailed changes
@@ -566,6 +566,41 @@ dependencies = [
"workspace",
]
+[[package]]
+name = "assistant_eval"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "assistant2",
+ "assistant_tool",
+ "assistant_tools",
+ "clap",
+ "client",
+ "collections",
+ "context_server",
+ "env_logger 0.11.7",
+ "fs",
+ "futures 0.3.31",
+ "gpui",
+ "gpui_tokio",
+ "itertools 0.14.0",
+ "language",
+ "language_model",
+ "language_models",
+ "node_runtime",
+ "project",
+ "prompt_store",
+ "regex",
+ "release_channel",
+ "reqwest_client",
+ "serde",
+ "serde_json",
+ "serde_json_lenient",
+ "settings",
+ "smol",
+ "util",
+]
+
[[package]]
name = "assistant_settings"
version = "0.1.0"
@@ -8,6 +8,7 @@ members = [
"crates/assistant",
"crates/assistant2",
"crates/assistant_context_editor",
+ "crates/assistant_eval",
"crates/assistant_settings",
"crates/assistant_slash_command",
"crates/assistant_slash_commands",
@@ -206,6 +207,7 @@ assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant2 = { path = "crates/assistant2" }
assistant_context_editor = { path = "crates/assistant_context_editor" }
+assistant_eval = { path = "crates/assistant_eval" }
assistant_settings = { path = "crates/assistant_settings" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
@@ -298,6 +298,7 @@ impl ActiveThread {
ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => {
self.save_thread(cx);
}
+ ThreadEvent::DoneStreaming => {}
ThreadEvent::StreamedAssistantText(message_id, text) => {
if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
markdown.update(cx, |markdown, cx| {
@@ -31,8 +31,11 @@ use gpui::{actions, App};
use prompt_store::PromptBuilder;
use settings::Settings as _;
+pub use crate::active_thread::ActiveThread;
pub use crate::assistant_panel::{AssistantPanel, ConcreteAssistantPanelDelegate};
pub use crate::inline_assistant::InlineAssistant;
+pub use crate::thread::{Message, RequestKind, Thread, ThreadEvent};
+pub use crate::thread_store::ThreadStore;
actions!(
assistant2,
@@ -284,6 +284,10 @@ impl Thread {
self.tool_use.tool_results_for_message(id)
}
+ pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
+ self.tool_use.tool_result(id)
+ }
+
pub fn scripting_tool_results_for_message(
&self,
id: MessageId,
@@ -652,32 +656,37 @@ impl Thread {
let result = stream_completion.await;
thread
- .update(&mut cx, |thread, cx| match result.as_ref() {
- Ok(stop_reason) => match stop_reason {
- StopReason::ToolUse => {
- cx.emit(ThreadEvent::UsePendingTools);
- }
- StopReason::EndTurn => {}
- StopReason::MaxTokens => {}
- },
- Err(error) => {
- if error.is::<PaymentRequiredError>() {
- cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
- } else if error.is::<MaxMonthlySpendReachedError>() {
- cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
- } else {
- let error_message = error
- .chain()
- .map(|err| err.to_string())
- .collect::<Vec<_>>()
- .join("\n");
- cx.emit(ThreadEvent::ShowError(ThreadError::Message(
- SharedString::from(error_message.clone()),
- )));
- }
+ .update(&mut cx, |thread, cx| {
+ match result.as_ref() {
+ Ok(stop_reason) => match stop_reason {
+ StopReason::ToolUse => {
+ cx.emit(ThreadEvent::UsePendingTools);
+ }
+ StopReason::EndTurn => {}
+ StopReason::MaxTokens => {}
+ },
+ Err(error) => {
+ if error.is::<PaymentRequiredError>() {
+ cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
+ } else if error.is::<MaxMonthlySpendReachedError>() {
+ cx.emit(ThreadEvent::ShowError(
+ ThreadError::MaxMonthlySpendReached,
+ ));
+ } else {
+ let error_message = error
+ .chain()
+ .map(|err| err.to_string())
+ .collect::<Vec<_>>()
+ .join("\n");
+ cx.emit(ThreadEvent::ShowError(ThreadError::Message(
+ SharedString::from(error_message.clone()),
+ )));
+ }
- thread.cancel_last_completion();
+ thread.cancel_last_completion();
+ }
}
+ cx.emit(ThreadEvent::DoneStreaming);
})
.ok();
});
@@ -1094,6 +1103,7 @@ pub enum ThreadEvent {
ShowError(ThreadError),
StreamedCompletion,
StreamedAssistantText(MessageId, String),
+ DoneStreaming,
MessageAdded(MessageId),
MessageEdited(MessageId),
MessageDeleted(MessageId),
@@ -182,6 +182,13 @@ impl ToolUseState {
.map_or(false, |results| !results.is_empty())
}
+ pub fn tool_result(
+ &self,
+ tool_use_id: &LanguageModelToolUseId,
+ ) -> Option<&LanguageModelToolResult> {
+ self.tool_results.get(tool_use_id)
+ }
+
pub fn request_tool_use(
&mut self,
assistant_message_id: MessageId,
@@ -0,0 +1,44 @@
+[package]
+name = "assistant_eval"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[[bin]]
+name = "assistant_eval"
+path = "src/main.rs"
+
+[dependencies]
+anyhow.workspace = true
+assistant2.workspace = true
+assistant_tool.workspace = true
+assistant_tools.workspace = true
+clap.workspace = true
+client.workspace = true
+collections.workspace = true
+context_server.workspace = true
+env_logger.workspace = true
+fs.workspace = true
+futures.workspace = true
+gpui.workspace = true
+gpui_tokio.workspace = true
+itertools.workspace = true
+language.workspace = true
+language_model.workspace = true
+language_models.workspace = true
+node_runtime.workspace = true
+project.workspace = true
+prompt_store.workspace = true
+regex.workspace = true
+release_channel.workspace = true
+reqwest_client.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+serde_json_lenient.workspace = true
+settings.workspace = true
+smol.workspace = true
+util.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,77 @@
+# Tool Evals
+
+A framework for evaluating and benchmarking AI assistant performance in the Zed editor.
+
+## Overview
+
+Tool Evals provides a headless environment for running assistants evaluations on code repositories. It automates the process of:
+
+1. Cloning and setting up test repositories
+2. Sending prompts to language models
+3. Allowing the assistant to use tools to modify code
+4. Collecting metrics on performance
+5. Evaluating results against known good solutions
+
+## How It Works
+
+The system consists of several key components:
+
+- **Eval**: Loads test cases from the evaluation_data directory, clones repos, and executes evaluations
+- **HeadlessAssistant**: Provides a headless environment for running the AI assistant
+- **Judge**: Compares AI-generated diffs with reference solutions and scores their functional similarity
+
+The evaluation flow:
+1. An evaluation is loaded from the evaluation_data directory
+2. The target repository is cloned and checked out at a specific commit
+3. A HeadlessAssistant instance is created with the specified language model
+4. The user prompt is sent to the assistant
+5. The assistant responds and uses tools to modify code
+6. Upon completion, a diff is generated from the changes
+7. Results are saved including the diff, assistant's response, and performance metrics
+8. If a reference solution exists, a Judge evaluates the similarity of the solution
+
+## Setup Requirements
+
+### Prerequisites
+
+- Rust and Cargo
+- Git
+- Network access to clone repositories
+- Appropriate API keys for language models and git services (Anthropic, GitHub, etc.)
+
+### Environment Variables
+
+Ensure you have the required API keys set, either from a dev run of Zed or via these environment variables:
+- `ZED_ANTHROPIC_API_KEY` for Claude models
+- `ZED_OPENAI_API_KEY` for OpenAI models
+- `ZED_GITHUB_API_KEY` for GitHub API (or similar)
+
+## Usage
+
+### Running a Single Evaluation
+
+To run a specific evaluation:
+
+```bash
+cargo run -p assistant_eval -- bubbletea-add-set-window-title
+```
+
+The arguments are regex patterns for the evaluation names to run, so to run all evaluations that contain `bubbletea`, run:
+
+```bash
+cargo run -p assistant_eval -- bubbletea
+```
+
+To run all evaluations:
+
+```bash
+cargo run -p assistant_eval -- --all
+```
+
+## Evaluation Data Structure
+
+Each evaluation should be placed in the `evaluation_data` directory with the following structure:
+
+* `prompt.txt`: The user's prompt.
+* `original.diff`: The `git diff` of the change anticipated for this prompt.
+* `setup.json`: Information about the repo used for the evaluation.
@@ -0,0 +1,61 @@
+// Copied from `crates/zed/build.rs`, with removal of code for including the zed icon on windows.
+
+use std::process::Command;
+
+fn main() {
+ if cfg!(target_os = "macos") {
+ println!("cargo:rustc-env=MACOSX_DEPLOYMENT_TARGET=10.15.7");
+
+ println!("cargo:rerun-if-env-changed=ZED_BUNDLE");
+ if std::env::var("ZED_BUNDLE").ok().as_deref() == Some("true") {
+ // Find WebRTC.framework in the Frameworks folder when running as part of an application bundle.
+ println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path/../Frameworks");
+ } else {
+ // Find WebRTC.framework as a sibling of the executable when running outside of an application bundle.
+ println!("cargo:rustc-link-arg=-Wl,-rpath,@executable_path");
+ }
+
+ // Weakly link ReplayKit to ensure Zed can be used on macOS 10.15+.
+ println!("cargo:rustc-link-arg=-Wl,-weak_framework,ReplayKit");
+
+ // Seems to be required to enable Swift concurrency
+ println!("cargo:rustc-link-arg=-Wl,-rpath,/usr/lib/swift");
+
+ // Register exported Objective-C selectors, protocols, etc
+ println!("cargo:rustc-link-arg=-Wl,-ObjC");
+ }
+
+ // Populate git sha environment variable if git is available
+ println!("cargo:rerun-if-changed=../../.git/logs/HEAD");
+ println!(
+ "cargo:rustc-env=TARGET={}",
+ std::env::var("TARGET").unwrap()
+ );
+ if let Ok(output) = Command::new("git").args(["rev-parse", "HEAD"]).output() {
+ if output.status.success() {
+ let git_sha = String::from_utf8_lossy(&output.stdout);
+ let git_sha = git_sha.trim();
+
+ println!("cargo:rustc-env=ZED_COMMIT_SHA={git_sha}");
+
+ if let Ok(build_profile) = std::env::var("PROFILE") {
+ if build_profile == "release" {
+ // This is currently the best way to make `cargo build ...`'s build script
+ // to print something to stdout without extra verbosity.
+ println!(
+ "cargo:warning=Info: using '{git_sha}' hash for ZED_COMMIT_SHA env var"
+ );
+ }
+ }
+ }
+ }
+
+ #[cfg(target_os = "windows")]
+ {
+ #[cfg(target_env = "msvc")]
+ {
+ // todo(windows): This is to avoid stack overflow. Remove it when solved.
+ println!("cargo:rustc-link-arg=/stack:{}", 8 * 1024 * 1024);
+ }
+ }
+}
@@ -0,0 +1,252 @@
+use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
+use anyhow::anyhow;
+use assistant2::RequestKind;
+use collections::HashMap;
+use gpui::{App, Task};
+use language_model::{LanguageModel, TokenUsage};
+use serde::{Deserialize, Serialize};
+use std::{
+ fs,
+ io::Write,
+ path::{Path, PathBuf},
+ sync::Arc,
+ time::Duration,
+};
+use util::command::new_smol_command;
+
+pub struct Eval {
+ pub name: String,
+ pub path: PathBuf,
+ pub repo_path: PathBuf,
+ pub eval_setup: EvalSetup,
+ pub user_prompt: String,
+}
+
+#[derive(Debug, Serialize)]
+pub struct EvalOutput {
+ pub diff: String,
+ pub last_message: String,
+ pub elapsed_time: Duration,
+ pub assistant_response_count: usize,
+ pub tool_use_counts: HashMap<Arc<str>, u32>,
+ pub token_usage: TokenUsage,
+}
+
+#[derive(Deserialize)]
+pub struct EvalSetup {
+ pub url: String,
+ pub base_sha: String,
+}
+
+impl Eval {
+ /// Loads the eval from a path (typically in `evaluation_data`). Clones and checks out the repo
+ /// if necessary.
+ pub async fn load(name: String, path: PathBuf, repos_dir: &Path) -> anyhow::Result<Self> {
+ let prompt_path = path.join("prompt.txt");
+ let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
+ let setup_path = path.join("setup.json");
+ let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
+ let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;
+ let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));
+ Ok(Eval {
+ name,
+ path,
+ repo_path,
+ eval_setup,
+ user_prompt,
+ })
+ }
+
+ pub fn run(
+ self,
+ app_state: Arc<HeadlessAppState>,
+ model: Arc<dyn LanguageModel>,
+ cx: &mut App,
+ ) -> Task<anyhow::Result<EvalOutput>> {
+ cx.spawn(move |mut cx| async move {
+ checkout_repo(&self.eval_setup, &self.repo_path).await?;
+
+ let (assistant, done_rx) =
+ cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;
+
+ let _worktree = assistant
+ .update(&mut cx, |assistant, cx| {
+ assistant.project.update(cx, |project, cx| {
+ project.create_worktree(&self.repo_path, true, cx)
+ })
+ })?
+ .await?;
+
+ let start_time = std::time::SystemTime::now();
+
+ assistant.update(&mut cx, |assistant, cx| {
+ assistant.thread.update(cx, |thread, cx| {
+ let context = vec![];
+ thread.insert_user_message(self.user_prompt.clone(), context, cx);
+ thread.send_to_model(model, RequestKind::Chat, cx);
+ });
+ })?;
+
+ done_rx.recv().await??;
+
+ let elapsed_time = start_time.elapsed()?;
+
+ let diff = query_git(&self.repo_path, vec!["diff"]).await?;
+
+ assistant.update(&mut cx, |assistant, cx| {
+ let thread = assistant.thread.read(cx);
+ let last_message = thread.messages().last().unwrap();
+ if last_message.role != language_model::Role::Assistant {
+ return Err(anyhow!("Last message is not from assistant"));
+ }
+ let assistant_response_count = thread
+ .messages()
+ .filter(|message| message.role == language_model::Role::Assistant)
+ .count();
+ Ok(EvalOutput {
+ diff,
+ last_message: last_message.text.clone(),
+ elapsed_time,
+ assistant_response_count,
+ tool_use_counts: assistant.tool_use_counts.clone(),
+ token_usage: thread.cumulative_token_usage(),
+ })
+ })?
+ })
+ }
+}
+
+impl EvalOutput {
+ // Method to save the output to a directory
+ pub fn save_to_directory(
+ &self,
+ output_dir: &Path,
+ eval_output_value: String,
+ ) -> anyhow::Result<()> {
+ // Create the output directory if it doesn't exist
+ fs::create_dir_all(&output_dir)?;
+
+ // Save the diff to a file
+ let diff_path = output_dir.join("diff.patch");
+ let mut diff_file = fs::File::create(&diff_path)?;
+ diff_file.write_all(self.diff.as_bytes())?;
+
+ // Save the last message to a file
+ let message_path = output_dir.join("assistant_response.txt");
+ let mut message_file = fs::File::create(&message_path)?;
+ message_file.write_all(self.last_message.as_bytes())?;
+
+ // Current metrics for this run
+ let current_metrics = serde_json::json!({
+ "elapsed_time_ms": self.elapsed_time.as_millis(),
+ "assistant_response_count": self.assistant_response_count,
+ "tool_use_counts": self.tool_use_counts,
+ "token_usage": self.token_usage,
+ "eval_output_value": eval_output_value,
+ });
+
+ // Get current timestamp in milliseconds
+ let timestamp = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)?
+ .as_millis()
+ .to_string();
+
+ // Path to metrics file
+ let metrics_path = output_dir.join("metrics.json");
+
+ // Load existing metrics if the file exists, or create a new object
+ let mut historical_metrics = if metrics_path.exists() {
+ let metrics_content = fs::read_to_string(&metrics_path)?;
+ serde_json::from_str::<serde_json::Value>(&metrics_content)
+ .unwrap_or_else(|_| serde_json::json!({}))
+ } else {
+ serde_json::json!({})
+ };
+
+ // Add new run with timestamp as key
+ if let serde_json::Value::Object(ref mut map) = historical_metrics {
+ map.insert(timestamp, current_metrics);
+ }
+
+ // Write updated metrics back to file
+ let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
+ let mut metrics_file = fs::File::create(&metrics_path)?;
+ metrics_file.write_all(metrics_json.as_bytes())?;
+
+ Ok(())
+ }
+}
+
+fn repo_dir_name(url: &str) -> String {
+ url.trim_start_matches("https://")
+ .replace(|c: char| !c.is_alphanumeric(), "_")
+}
+
+async fn checkout_repo(eval_setup: &EvalSetup, repo_path: &Path) -> anyhow::Result<()> {
+ if !repo_path.exists() {
+ smol::unblock({
+ let repo_path = repo_path.to_path_buf();
+ || std::fs::create_dir_all(repo_path)
+ })
+ .await?;
+ run_git(repo_path, vec!["init"]).await?;
+ run_git(repo_path, vec!["remote", "add", "origin", &eval_setup.url]).await?;
+ } else {
+ let actual_origin = query_git(repo_path, vec!["remote", "get-url", "origin"]).await?;
+ if actual_origin != eval_setup.url {
+ return Err(anyhow!(
+ "remote origin {} does not match expected origin {}",
+ actual_origin,
+ eval_setup.url
+ ));
+ }
+
+ // TODO: consider including "-x" to remove ignored files. The downside of this is that it will
+ // also remove build artifacts, and so prevent incremental reuse there.
+ run_git(repo_path, vec!["clean", "--force", "-d"]).await?;
+ run_git(repo_path, vec!["reset", "--hard", "HEAD"]).await?;
+ }
+
+ run_git(
+ repo_path,
+ vec!["fetch", "--depth", "1", "origin", &eval_setup.base_sha],
+ )
+ .await?;
+ run_git(repo_path, vec!["checkout", &eval_setup.base_sha]).await?;
+
+ Ok(())
+}
+
+async fn run_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<()> {
+ let exit_status = new_smol_command("git")
+ .current_dir(repo_path)
+ .args(args.clone())
+ .status()
+ .await?;
+ if exit_status.success() {
+ Ok(())
+ } else {
+ Err(anyhow!(
+ "`git {}` failed with {}",
+ args.join(" "),
+ exit_status,
+ ))
+ }
+}
+
+async fn query_git(repo_path: &Path, args: Vec<&str>) -> anyhow::Result<String> {
+ let output = new_smol_command("git")
+ .current_dir(repo_path)
+ .args(args.clone())
+ .output()
+ .await?;
+ if output.status.success() {
+ Ok(String::from_utf8(output.stdout)?.trim().to_string())
+ } else {
+ Err(anyhow!(
+ "`git {}` failed with {}",
+ args.join(" "),
+ output.status
+ ))
+ }
+}
@@ -0,0 +1,241 @@
+use anyhow::anyhow;
+use assistant2::{Thread, ThreadEvent, ThreadStore};
+use assistant_tool::ToolWorkingSet;
+use client::{Client, UserStore};
+use collections::HashMap;
+use futures::StreamExt;
+use gpui::{prelude::*, App, AsyncApp, Entity, SemanticVersion, Subscription, Task};
+use language::LanguageRegistry;
+use language_model::{
+ AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
+ LanguageModelRequest,
+};
+use node_runtime::NodeRuntime;
+use project::{Project, RealFs};
+use prompt_store::PromptBuilder;
+use settings::SettingsStore;
+use smol::channel;
+use std::sync::Arc;
+
+/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
+pub struct HeadlessAppState {
+ pub languages: Arc<LanguageRegistry>,
+ pub client: Arc<Client>,
+ pub user_store: Entity<UserStore>,
+ pub fs: Arc<dyn fs::Fs>,
+ pub node_runtime: NodeRuntime,
+
+ // Additional fields not present in `workspace::AppState`.
+ pub prompt_builder: Arc<PromptBuilder>,
+}
+
+pub struct HeadlessAssistant {
+ pub thread: Entity<Thread>,
+ pub project: Entity<Project>,
+ #[allow(dead_code)]
+ pub thread_store: Entity<ThreadStore>,
+ pub tool_use_counts: HashMap<Arc<str>, u32>,
+ pub done_tx: channel::Sender<anyhow::Result<()>>,
+ _subscription: Subscription,
+}
+
+impl HeadlessAssistant {
+ pub fn new(
+ app_state: Arc<HeadlessAppState>,
+ cx: &mut App,
+ ) -> anyhow::Result<(Entity<Self>, channel::Receiver<anyhow::Result<()>>)> {
+ let env = None;
+ let project = Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ env,
+ cx,
+ );
+
+ let tools = Arc::new(ToolWorkingSet::default());
+ let thread_store =
+ ThreadStore::new(project.clone(), tools, app_state.prompt_builder.clone(), cx)?;
+
+ let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx));
+
+ let (done_tx, done_rx) = channel::unbounded::<anyhow::Result<()>>();
+
+ let headless_thread = cx.new(move |cx| Self {
+ _subscription: cx.subscribe(&thread, Self::handle_thread_event),
+ thread,
+ project,
+ thread_store,
+ tool_use_counts: HashMap::default(),
+ done_tx,
+ });
+
+ Ok((headless_thread, done_rx))
+ }
+
+ fn handle_thread_event(
+ &mut self,
+ thread: Entity<Thread>,
+ event: &ThreadEvent,
+ cx: &mut Context<Self>,
+ ) {
+ match event {
+ ThreadEvent::ShowError(err) => self
+ .done_tx
+ .send_blocking(Err(anyhow!("{:?}", err)))
+ .unwrap(),
+ ThreadEvent::DoneStreaming => {
+ let thread = thread.read(cx);
+ if let Some(message) = thread.messages().last() {
+ println!("Message: {}", message.text,);
+ }
+ if thread.all_tools_finished() {
+ self.done_tx.send_blocking(Ok(())).unwrap()
+ }
+ }
+ ThreadEvent::UsePendingTools => {
+ thread.update(cx, |thread, cx| {
+ thread.use_pending_tools(cx);
+ });
+ }
+ ThreadEvent::ToolFinished {
+ tool_use_id,
+ pending_tool_use,
+ } => {
+ if let Some(pending_tool_use) = pending_tool_use {
+ println!(
+ "Used tool {} with input: {}",
+ pending_tool_use.name, pending_tool_use.input
+ );
+ *self
+ .tool_use_counts
+ .entry(pending_tool_use.name.clone())
+ .or_insert(0) += 1;
+ }
+ if let Some(tool_result) = thread.read(cx).tool_result(tool_use_id) {
+ println!("Tool result: {:?}", tool_result);
+ }
+ if thread.read(cx).all_tools_finished() {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ if let Some(model) = model_registry.active_model() {
+ thread.update(cx, |thread, cx| {
+ // Currently evals do not support specifying context.
+ let updated_context = vec![];
+ thread.send_tool_results_to_model(model, updated_context, cx);
+ });
+ }
+ }
+ }
+ ThreadEvent::StreamedCompletion
+ | ThreadEvent::SummaryChanged
+ | ThreadEvent::StreamedAssistantText(_, _)
+ | ThreadEvent::MessageAdded(_)
+ | ThreadEvent::MessageEdited(_)
+ | ThreadEvent::MessageDeleted(_) => {}
+ }
+ }
+}
+
+pub fn init(cx: &mut App) -> Arc<HeadlessAppState> {
+ release_channel::init(SemanticVersion::default(), cx);
+ gpui_tokio::init(cx);
+
+ let mut settings_store = SettingsStore::new(cx);
+ settings_store
+ .set_default_settings(settings::default_settings().as_ref(), cx)
+ .unwrap();
+ cx.set_global(settings_store);
+ client::init_settings(cx);
+ Project::init_settings(cx);
+
+ let client = Client::production(cx);
+ cx.set_http_client(client.http_client().clone());
+
+ let git_binary_path = None;
+ let fs = Arc::new(RealFs::new(git_binary_path));
+
+ let languages = Arc::new(LanguageRegistry::new(cx.background_executor().clone()));
+
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+
+ language::init(cx);
+ language_model::init(client.clone(), cx);
+ language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
+ assistant_tools::init(cx);
+ context_server::init(cx);
+ let stdout_is_a_pty = false;
+ let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
+ assistant2::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
+
+ Arc::new(HeadlessAppState {
+ languages,
+ client,
+ user_store,
+ fs,
+ node_runtime: NodeRuntime::unavailable(),
+ prompt_builder,
+ })
+}
+
+pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let model = model_registry
+ .available_models(cx)
+ .find(|model| model.id().0 == model_name);
+
+ let Some(model) = model else {
+ return Err(anyhow!(
+ "No language model named {} was available. Available models: {}",
+ model_name,
+ model_registry
+ .available_models(cx)
+ .map(|model| model.id().0.clone())
+ .collect::<Vec<_>>()
+ .join(", ")
+ ));
+ };
+
+ Ok(model)
+}
+
+pub fn authenticate_model_provider(
+ provider_id: LanguageModelProviderId,
+ cx: &mut App,
+) -> Task<std::result::Result<(), AuthenticateError>> {
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let model_provider = model_registry.provider(&provider_id).unwrap();
+ model_provider.authenticate(cx)
+}
+
+pub async fn send_language_model_request(
+ model: Arc<dyn LanguageModel>,
+ request: LanguageModelRequest,
+ cx: AsyncApp,
+) -> anyhow::Result<String> {
+ match model.stream_completion_text(request, &cx).await {
+ Ok(mut stream) => {
+ let mut full_response = String::new();
+
+ // Process the response stream
+ while let Some(chunk_result) = stream.stream.next().await {
+ match chunk_result {
+ Ok(chunk_str) => {
+ full_response.push_str(&chunk_str);
+ }
+ Err(err) => {
+ return Err(anyhow!(
+ "Error receiving response from language model: {err}"
+ ));
+ }
+ }
+ }
+
+ Ok(full_response)
+ }
+ Err(err) => Err(anyhow!(
+ "Failed to get response from language model. Error was: {err}"
+ )),
+ }
+}
@@ -0,0 +1,121 @@
+use crate::eval::EvalOutput;
+use crate::headless_assistant::send_language_model_request;
+use anyhow::anyhow;
+use gpui::{App, Task};
+use language_model::{
+ LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
+};
+use std::{path::Path, sync::Arc};
+
+pub struct Judge {
+ pub original_diff: Option<String>,
+ #[allow(dead_code)]
+ pub original_message: Option<String>,
+ pub model: Arc<dyn LanguageModel>,
+}
+
+impl Judge {
+ pub async fn load(eval_path: &Path, model: Arc<dyn LanguageModel>) -> anyhow::Result<Judge> {
+ let original_diff_path = eval_path.join("original.diff");
+ let original_diff = smol::unblock(move || {
+ if std::fs::exists(&original_diff_path)? {
+ anyhow::Ok(Some(std::fs::read_to_string(&original_diff_path)?))
+ } else {
+ anyhow::Ok(None)
+ }
+ });
+
+ let original_message_path = eval_path.join("original_message.txt");
+ let original_message = smol::unblock(move || {
+ if std::fs::exists(&original_message_path)? {
+ anyhow::Ok(Some(std::fs::read_to_string(&original_message_path)?))
+ } else {
+ anyhow::Ok(None)
+ }
+ });
+
+ Ok(Self {
+ original_diff: original_diff.await?,
+ original_message: original_message.await?,
+ model,
+ })
+ }
+
+ pub fn run(&self, eval_output: &EvalOutput, cx: &mut App) -> Task<anyhow::Result<String>> {
+ let Some(original_diff) = self.original_diff.as_ref() else {
+ return Task::ready(Err(anyhow!("No original.diff found")));
+ };
+
+ // TODO: check for empty diff?
+ let prompt = diff_comparison_prompt(&original_diff, &eval_output.diff);
+
+ let request = LanguageModelRequest {
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::Text(prompt)],
+ cache: false,
+ }],
+ temperature: Some(0.0),
+ tools: Vec::new(),
+ stop: Vec::new(),
+ };
+
+ let model = self.model.clone();
+ cx.spawn(move |cx| send_language_model_request(model, request, cx))
+ }
+}
+
+pub fn diff_comparison_prompt(original_diff: &str, new_diff: &str) -> String {
+ format!(
+ r#"# Git Diff Similarity Evaluation Template
+
+## Instructions
+
+Compare the two diffs and score them between 0.0 and 1.0 based on their functional similarity.
+- 1.0 = Perfect functional match (achieves identical results)
+- 0.0 = No functional similarity whatsoever
+
+## Evaluation Criteria
+
+Please consider the following aspects in order of importance:
+
+1. **Functional Equivalence (60%)**
+ - Do both diffs achieve the same end result?
+ - Are the changes functionally equivalent despite possibly using different approaches?
+ - Do the modifications address the same issues or implement the same features?
+
+2. **Logical Structure (20%)**
+ - Are the logical flows similar?
+ - Do the modifications affect the same code paths?
+ - Are control structures (if/else, loops, etc.) modified in similar ways?
+
+3. **Code Content (15%)**
+ - Are similar lines added/removed?
+ - Are the same variables, functions, or methods being modified?
+ - Are the same APIs or libraries being used?
+
+4. **File Layout (5%)**
+ - Are the same files being modified?
+ - Are changes occurring in similar locations within files?
+
+## Input
+
+Original Diff:
+```git
+{}
+```
+
+New Diff:
+```git
+{}
+```
+
+## Output Format
+
+THE ONLY OUTPUT SHOULD BE A SCORE BETWEEN 0.0 AND 1.0.
+
+Example output:
+0.85"#,
+ original_diff, new_diff
+ )
+}
@@ -0,0 +1,234 @@
+mod eval;
+mod headless_assistant;
+mod judge;
+
+use clap::Parser;
+use eval::{Eval, EvalOutput};
+use futures::{stream, StreamExt};
+use gpui::{Application, AsyncApp};
+use headless_assistant::{authenticate_model_provider, find_model, HeadlessAppState};
+use itertools::Itertools;
+use judge::Judge;
+use language_model::{LanguageModel, LanguageModelRegistry};
+use regex::Regex;
+use reqwest_client::ReqwestClient;
+use std::{cmp, path::PathBuf, sync::Arc};
+
+#[derive(Parser, Debug)]
+#[command(
+ name = "assistant_eval",
+ disable_version_flag = true,
+ before_help = "Tool eval runner"
+)]
+struct Args {
+ /// Regexes to match the names of evals to run.
+ eval_name_regexes: Vec<String>,
+ /// Runs all evals in `evaluation_data`, causes the regex to be ignored.
+ #[arg(long)]
+ all: bool,
+ /// Name of the model (default: "claude-3-7-sonnet-latest")
+ #[arg(long, default_value = "claude-3-7-sonnet-latest")]
+ model_name: String,
+ /// Name of the editor model (default: value of `--model_name`).
+ #[arg(long)]
+ editor_model_name: Option<String>,
+ /// Name of the judge model (default: value of `--model_name`).
+ #[arg(long)]
+ judge_model_name: Option<String>,
+ /// Number of evaluations to run concurrently (default: 10)
+ #[arg(short, long, default_value = "10")]
+ concurrency: usize,
+}
+
+fn main() {
+ env_logger::init();
+ let args = Args::parse();
+ let http_client = Arc::new(ReqwestClient::new());
+ let app = Application::headless().with_http_client(http_client.clone());
+
+ let crate_dir = PathBuf::from("../zed-agent-bench");
+ let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
+ let repos_dir = crate_dir.join("repos").canonicalize().unwrap();
+
+ let all_evals = std::fs::read_dir(&evaluation_data_dir)
+ .unwrap()
+ .map(|path| path.unwrap().file_name().to_string_lossy().to_string())
+ .collect::<Vec<_>>();
+
+ let evals_to_run = if args.all {
+ all_evals
+ } else {
+ args.eval_name_regexes
+ .into_iter()
+ .map(|regex_string| Regex::new(®ex_string).unwrap())
+ .flat_map(|regex| {
+ all_evals
+ .iter()
+ .filter(|eval_name| regex.is_match(eval_name))
+ .cloned()
+ .collect::<Vec<_>>()
+ })
+ .collect::<Vec<_>>()
+ };
+
+ if evals_to_run.is_empty() {
+ panic!("Names of evals to run must be provided or `--all` specified");
+ }
+
+ println!("Will run the following evals: {evals_to_run:?}");
+ println!("Running up to {} evals concurrently", args.concurrency);
+
+ let editor_model_name = if let Some(model_name) = args.editor_model_name {
+ model_name
+ } else {
+ args.model_name.clone()
+ };
+
+ let judge_model_name = if let Some(model_name) = args.judge_model_name {
+ model_name
+ } else {
+ args.model_name.clone()
+ };
+
+ app.run(move |cx| {
+ let app_state = headless_assistant::init(cx);
+
+ let model = find_model(&args.model_name, cx).unwrap();
+ let editor_model = find_model(&editor_model_name, cx).unwrap();
+ let judge_model = find_model(&judge_model_name, cx).unwrap();
+
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry.set_active_model(Some(model.clone()), cx);
+ registry.set_editor_model(Some(editor_model.clone()), cx);
+ });
+
+ let model_provider_id = model.provider_id();
+ let editor_model_provider_id = editor_model.provider_id();
+ let judge_model_provider_id = judge_model.provider_id();
+
+ cx.spawn(move |cx| async move {
+ // Authenticate all model providers first
+ cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
+ .unwrap()
+ .await
+ .unwrap();
+ cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
+ .unwrap()
+ .await
+ .unwrap();
+ cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
+ .unwrap()
+ .await
+ .unwrap();
+
+ let loaded_evals = stream::iter(evals_to_run)
+ .map(|eval_name| {
+ let eval_path = evaluation_data_dir.join(&eval_name);
+ let repos_dir = repos_dir.clone();
+ async move {
+ match Eval::load(eval_name.clone(), eval_path, &repos_dir).await {
+ Ok(eval) => Some(eval),
+ Err(err) => {
+ // TODO: Persist errors / surface errors at the end.
+ println!("Error loading {eval_name}: {err}");
+ None
+ }
+ }
+ }
+ })
+ .buffer_unordered(args.concurrency)
+ .collect::<Vec<_>>()
+ .await
+ .into_iter()
+ .flatten()
+ .collect::<Vec<_>>();
+
+ // The evals need to be loaded and grouped by URL before concurrently running, since
+ // evals that use the same remote URL will use the same working directory.
+ let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
+ .into_iter()
+ .map(|eval| (eval.eval_setup.url.clone(), eval))
+ .into_group_map()
+ .into_values()
+ .collect::<Vec<_>>();
+
+ // Sort groups in descending order, so that bigger groups start first.
+ evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
+
+ let results = stream::iter(evals_grouped_by_url)
+ .map(|evals| {
+ let model = model.clone();
+ let judge_model = judge_model.clone();
+ let app_state = app_state.clone();
+ let cx = cx.clone();
+
+ async move {
+ let mut results = Vec::new();
+ for eval in evals {
+ let name = eval.name.clone();
+ println!("Starting eval named {}", name);
+ let result = run_eval(
+ eval,
+ model.clone(),
+ judge_model.clone(),
+ app_state.clone(),
+ cx.clone(),
+ )
+ .await;
+ results.push((name, result));
+ }
+ results
+ }
+ })
+ .buffer_unordered(args.concurrency)
+ .collect::<Vec<_>>()
+ .await
+ .into_iter()
+ .flatten()
+ .collect::<Vec<_>>();
+
+ // Process results in order of completion
+ for (eval_name, result) in results {
+ match result {
+ Ok((eval_output, judge_output)) => {
+ println!("Generated diff for {eval_name}:\n");
+ println!("{}\n", eval_output.diff);
+ println!("Last message for {eval_name}:\n");
+ println!("{}\n", eval_output.last_message);
+ println!("Elapsed time: {:?}", eval_output.elapsed_time);
+ println!(
+ "Assistant response count: {}",
+ eval_output.assistant_response_count
+ );
+ println!("Tool use counts: {:?}", eval_output.tool_use_counts);
+ println!("Judge output for {eval_name}: {judge_output}");
+ }
+ Err(err) => {
+ // TODO: Persist errors / surface errors at the end.
+ println!("Error running {eval_name}: {err}");
+ }
+ }
+ }
+
+ cx.update(|cx| cx.quit()).unwrap();
+ })
+ .detach();
+ });
+
+ println!("Done running evals");
+}
+
+async fn run_eval(
+ eval: Eval,
+ model: Arc<dyn LanguageModel>,
+ judge_model: Arc<dyn LanguageModel>,
+ app_state: Arc<HeadlessAppState>,
+ cx: AsyncApp,
+) -> anyhow::Result<(EvalOutput, String)> {
+ let path = eval.path.clone();
+ let judge = Judge::load(&path, judge_model).await?;
+ let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
+ let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
+ eval_output.save_to_directory(&path, judge_output.to_string())?;
+ Ok((eval_output, judge_output))
+}