mod eval;
mod headless_assistant;
mod judge;

use clap::Parser;
use eval::{Eval, EvalOutput};
use futures::{StreamExt as _, future, stream};
use gpui::{Application, AsyncApp};
use headless_assistant::{HeadlessAppState, authenticate_model_provider, find_model};
use itertools::Itertools;
use judge::Judge;
use language_model::{LanguageModel, LanguageModelRegistry};
use regex::Regex;
use reqwest_client::ReqwestClient;
use std::{cmp, path::PathBuf, sync::Arc};

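// Example invocations (a sketch; the eval name pattern and model name are
// hypothetical placeholders, not entries that necessarily exist):
//
//     assistant_eval --all
//     assistant_eval 'some_eval_.*' --concurrency 4 --judge-model-name <model>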
#[derive(Parser, Debug)]
#[command(
    name = "assistant_eval",
    disable_version_flag = true,
    before_help = "Tool eval runner"
)]
struct Args {
    /// Regexes to match the names of evals to run.
    eval_name_regexes: Vec<String>,
    /// Runs all evals in `evaluation_data`; any name regexes are ignored.
    #[arg(long)]
    all: bool,
    /// Name of the model (default: "claude-3-7-sonnet-latest").
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model_name: String,
    /// Name of the editor model (default: value of `--model-name`).
    #[arg(long)]
    editor_model_name: Option<String>,
    /// Name of the judge model (default: value of `--model-name`).
    #[arg(long)]
    judge_model_name: Option<String>,
    /// Number of evaluations to run concurrently (default: 10).
    #[arg(short, long, default_value = "10")]
    concurrency: usize,
}

fn main() {
    env_logger::init();
    let args = Args::parse();
    let http_client = Arc::new(ReqwestClient::new());
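    // Headless gpui application: no windows are opened; the eval runner only
    // needs the executor and global app state.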
    let app = Application::headless().with_http_client(http_client.clone());

    let crate_dir = PathBuf::from("../zed-agent-bench");
    let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();

    let repos_dir = crate_dir.join("repos");
    if !repos_dir.exists() {
        std::fs::create_dir_all(&repos_dir).unwrap();
    }
    let repos_dir = repos_dir.canonicalize().unwrap();

    let all_evals = std::fs::read_dir(&evaluation_data_dir)
        .unwrap()
        .map(|entry| entry.unwrap().file_name().to_string_lossy().to_string())
        .collect::<Vec<_>>();

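    // Either run everything under `evaluation_data`, or only the evals whose
    // directory names match at least one of the provided regexes.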
    let evals_to_run = if args.all {
        all_evals
    } else {
        args.eval_name_regexes
            .into_iter()
            .map(|regex_string| Regex::new(&regex_string).unwrap())
            .flat_map(|regex| {
                all_evals
                    .iter()
                    .filter(|eval_name| regex.is_match(eval_name))
                    .cloned()
                    .collect::<Vec<_>>()
            })
            // Deduplicate, so an eval matched by several regexes runs only once.
            .unique()
            .collect::<Vec<_>>()
    };

    if evals_to_run.is_empty() {
        panic!("Names of evals to run must be provided or `--all` specified");
    }

    println!("Will run the following evals: {evals_to_run:?}");
    println!("Running up to {} evals concurrently", args.concurrency);

    let editor_model_name = args
        .editor_model_name
        .unwrap_or_else(|| args.model_name.clone());
    let judge_model_name = args
        .judge_model_name
        .unwrap_or_else(|| args.model_name.clone());
    let concurrency = args.concurrency;

    app.run(move |cx| {
        let app_state = headless_assistant::init(cx);

        let model = find_model(&args.model_name, cx).unwrap();
        let editor_model = find_model(&editor_model_name, cx).unwrap();
        let judge_model = find_model(&judge_model_name, cx).unwrap();

        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.set_active_model(Some(model.clone()), cx);
            registry.set_editor_model(Some(editor_model.clone()), cx);
        });

        let model_provider_id = model.provider_id();
        let editor_model_provider_id = editor_model.provider_id();
        let judge_model_provider_id = judge_model.provider_id();

        cx.spawn(async move |cx| {
            // Authenticate all model providers first.
            cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();
            cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();
            cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();

            let eval_load_futures = evals_to_run
                .into_iter()
                .map(|eval_name| {
                    let eval_path = evaluation_data_dir.join(&eval_name);
                    let load_future = Eval::load(eval_name.clone(), eval_path, &repos_dir);
                    async move {
                        match load_future.await {
                            Ok(eval) => Some(eval),
                            Err(err) => {
                                // TODO: Persist errors / surface errors at the end.
                                println!("Error loading {eval_name}: {err}");
                                None
                            }
                        }
                    }
                })
                .collect::<Vec<_>>();

            let loaded_evals = future::join_all(eval_load_futures)
                .await
                .into_iter()
                .flatten()
                .collect::<Vec<_>>();

            // The evals must be loaded and grouped by URL before running anything
            // concurrently, because evals that use the same remote URL share a
            // working directory and therefore cannot run in parallel.
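            // Shape: (url, eval) pairs -> HashMap<url, Vec<Eval>> -> Vec<Vec<Eval>>,
            // one bucket per distinct URL.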
            let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
                .into_iter()
                .map(|eval| (eval.eval_setup.url.clone(), eval))
                .into_group_map()
                .into_values()
                .collect::<Vec<_>>();

            // Sort groups in descending order of size, so that bigger groups start first.
            evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
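            // Each group runs its evals sequentially (they share a working
            // directory); the groups themselves run concurrently, capped below.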
            let result_futures = evals_grouped_by_url
                .into_iter()
                .map(|evals| {
                    let model = model.clone();
                    let judge_model = judge_model.clone();
                    let app_state = app_state.clone();
                    let cx = cx.clone();

                    async move {
                        let mut results = Vec::new();
                        for eval in evals {
                            let name = eval.name.clone();
                            println!("Starting eval named {name}");
                            let result = run_eval(
                                eval,
                                model.clone(),
                                judge_model.clone(),
                                app_state.clone(),
                                cx.clone(),
                            )
                            .await;
                            results.push((name, result));
                        }
                        results
                    }
                })
                .collect::<Vec<_>>();

            // Run at most `concurrency` groups at a time, so that no more than
            // `concurrency` evals are ever in flight at once.
            let results = stream::iter(result_futures)
                .buffered(concurrency)
                .collect::<Vec<_>>()
                .await
                .into_iter()
                .flatten()
                .collect::<Vec<_>>();

            // Report results in their original group order (`buffered` preserves
            // the input order, so this is not completion order).
            for (eval_name, result) in results {
                match result {
                    Ok((eval_output, judge_output)) => {
                        println!("Generated diff for {eval_name}:\n");
                        println!("{}\n", eval_output.diff);
                        println!("Last message for {eval_name}:\n");
                        println!("{}\n", eval_output.last_message);
                        println!("Elapsed time: {:?}", eval_output.elapsed_time);
                        println!(
                            "Assistant response count: {}",
                            eval_output.assistant_response_count
                        );
                        println!("Tool use counts: {:?}", eval_output.tool_use_counts);
                        println!("Judge output for {eval_name}: {judge_output}");
                    }
                    Err(err) => {
                        // TODO: Persist errors / surface errors at the end.
                        println!("Error running {eval_name}: {err}");
                    }
                }
            }

            cx.update(|cx| cx.quit()).unwrap();
        })
        .detach();
    });

    println!("Done running evals");
}
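/// Runs a single eval end to end: loads its judge, runs the assistant against
/// the eval's repository, asks the judge model to assess the resulting output,
/// and saves both outputs to the eval's directory.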
async fn run_eval(
    eval: Eval,
    model: Arc<dyn LanguageModel>,
    judge_model: Arc<dyn LanguageModel>,
    app_state: Arc<HeadlessAppState>,
    cx: AsyncApp,
) -> anyhow::Result<(EvalOutput, String)> {
    let path = eval.path.clone();
    let judge = Judge::load(&path, judge_model).await?;
    let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
    let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
    eval_output.save_to_directory(&path, judge_output.to_string())?;
    Ok((eval_output, judge_output))
}