1mod assertions;
2mod example;
3mod examples;
4mod explorer;
5mod ids;
6mod instance;
7mod tool_metrics;
8
9use assertions::{AssertionsReport, display_error_row};
10use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
11pub(crate) use tool_metrics::*;
12
13use ::fs::RealFs;
14use anyhow::anyhow;
15use clap::Parser;
16use client::{Client, ProxySettings, UserStore};
17use collections::{HashMap, HashSet};
18use extension::ExtensionHostProxy;
19use futures::future;
20use gpui::http_client::read_proxy_from_env;
21use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
22use gpui_tokio::Tokio;
23use language::LanguageRegistry;
24use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
25use node_runtime::{NodeBinaryOptions, NodeRuntime};
26use project::Project;
27use project::project_settings::ProjectSettings;
28use prompt_store::PromptBuilder;
29use release_channel::AppVersion;
30use reqwest_client::ReqwestClient;
31use settings::{Settings, SettingsStore};
32use std::cell::RefCell;
33use std::collections::VecDeque;
34use std::env;
35use std::path::{Path, PathBuf};
36use std::rc::Rc;
37use std::sync::{Arc, LazyLock};
38use util::ResultExt as _;
39
40static CARGO_MANIFEST_DIR: LazyLock<PathBuf> =
41 LazyLock::new(|| PathBuf::from(env!("CARGO_MANIFEST_DIR")));
42
43#[derive(Parser, Debug)]
44#[command(name = "eval", disable_version_flag = true)]
45struct Args {
46 /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
47 #[arg(value_name = "EXAMPLE_SUBSTRING")]
48 filter: Vec<String>,
49 /// ID of model to use.
50 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
51 model: String,
52 /// Model provider to use.
53 #[arg(long, default_value = "anthropic")]
54 provider: String,
55 #[arg(long, value_delimiter = ',', default_value = "rs,ts,py")]
56 languages: Vec<String>,
57 /// How many times to run each example.
58 #[arg(long, default_value = "8")]
59 repetitions: usize,
60 /// Maximum number of examples to run concurrently.
61 #[arg(long, default_value = "4")]
62 concurrency: usize,
63}
64
65fn main() {
66 dotenv::from_filename(CARGO_MANIFEST_DIR.join(".env")).ok();
67
68 env_logger::init();
69
70 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
71 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
72 let session_id = uuid::Uuid::new_v4().to_string();
73 let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
74 let run_id = match env::var("GITHUB_RUN_ID") {
75 Ok(run_id) => format!("github/{}", run_id),
76 Err(_) => format!("local/{}", run_timestamp),
77 };
78
79 let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
80 .parent()
81 .unwrap()
82 .parent()
83 .unwrap()
84 .canonicalize()
85 .unwrap();
86 let eval_crate_dir = root_dir.join("crates").join("eval");
87 let repos_dir = eval_crate_dir.join("repos");
88 let worktrees_dir = eval_crate_dir.join("worktrees");
89 let examples_dir = eval_crate_dir.join("src").join("examples");
90 let run_dir = eval_crate_dir
91 .join("runs")
92 .join(format!("{}", run_timestamp));
93 std::fs::create_dir_all(&run_dir).unwrap();
94 std::fs::create_dir_all(&repos_dir).unwrap();
95 std::fs::create_dir_all(&worktrees_dir).unwrap();
96 std::fs::create_dir_all(&examples_dir).unwrap();
97 std::fs::create_dir_all(&paths::config_dir()).unwrap();
98
99 let zed_commit_sha = commit_sha_for_path(&root_dir);
100 let zed_branch_name = git_branch_for_path(&root_dir);
101 let args = Args::parse();
102 let languages: HashSet<String> = args.languages.into_iter().collect();
103
104 let http_client = Arc::new(ReqwestClient::new());
105 let app = Application::headless().with_http_client(http_client.clone());
106 let all_threads = examples::all(&examples_dir);
107
108 app.run(move |cx| {
109 let app_state = init(cx);
110
111 let telemetry = app_state.client.telemetry();
112 telemetry.start(system_id, installation_id, session_id, cx);
113
114 let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
115 && telemetry.has_checksum_seed();
116 if enable_telemetry {
117 println!("Telemetry enabled");
118 telemetry::event!(
119 "Agent Eval Started",
120 zed_commit_sha = zed_commit_sha,
121 zed_branch_name = zed_branch_name,
122 run_id = run_id,
123 );
124 }
125
126 let mut cumulative_tool_metrics = ToolMetrics::default();
127
128 let model_registry = LanguageModelRegistry::read_global(cx);
129 let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap();
130 let model_provider_id = model.provider_id();
131 let model_provider = model_registry.provider(&model_provider_id).unwrap();
132
133 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
134 registry.set_default_model(
135 Some(ConfiguredModel {
136 provider: model_provider.clone(),
137 model: model.clone(),
138 }),
139 cx,
140 );
141 });
142
143 let authenticate_task = model_provider.authenticate(cx);
144
145 cx.spawn(async move |cx| {
146 authenticate_task.await.unwrap();
147
148 let mut examples = Vec::new();
149
150 const COLORS: [&str; 12] = [
151 "\x1b[31m", // Red
152 "\x1b[32m", // Green
153 "\x1b[33m", // Yellow
154 "\x1b[34m", // Blue
155 "\x1b[35m", // Magenta
156 "\x1b[36m", // Cyan
157 "\x1b[91m", // Bright Red
158 "\x1b[92m", // Bright Green
159 "\x1b[93m", // Bright Yellow
160 "\x1b[94m", // Bright Blue
161 "\x1b[95m", // Bright Magenta
162 "\x1b[96m", // Bright Cyan
163 ];
164
165 let mut skipped = Vec::new();
166
167 for thread in all_threads {
168 let meta = thread.meta();
169 if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
170 {
171 skipped.push(meta.name);
172 continue;
173 }
174
175 if let Some(language) = meta.language_server {
176 if !languages.contains(&language.file_extension) {
177 panic!(
178 "Eval for {:?} could not be run because no language server was found for extension {:?}",
179 meta.name,
180 language.file_extension
181 );
182 }
183 }
184
185 // TODO: This creates a worktree per repetition. Ideally these examples should
186 // either be run sequentially on the same worktree, or reuse worktrees when there
187 // are more examples to run than the concurrency limit.
188 for repetition_number in 0..args.repetitions {
189 let example_instance = ExampleInstance::new(
190 thread.clone(),
191 &repos_dir,
192 &run_dir,
193 &worktrees_dir,
194 repetition_number,
195 );
196
197 examples.push(example_instance);
198 }
199 }
200
201 if !skipped.is_empty() {
202 println!("Skipped threads: {}", skipped.join(", "));
203 }
204
205 if examples.is_empty() {
206 eprintln!("Filter matched no examples");
207 return cx.update(|cx| cx.quit());
208 }
209
210 let mut repo_urls = HashSet::default();
211 let mut clone_tasks = Vec::new();
212
213 let max_name_width = examples
214 .iter()
215 .map(|e| e.worktree_name().len())
216 .max()
217 .unwrap_or(0);
218
219 for (i, example_instance) in examples.iter_mut().enumerate() {
220 let color = COLORS[i % COLORS.len()].to_string();
221 example_instance.set_log_prefix_style(&color, max_name_width);
222
223 println!(
224 "{}Logging to: {}",
225 example_instance.log_prefix,
226 example_instance.run_directory.display()
227 );
228
229 let repo_url = example_instance.repo_url();
230 if repo_urls.insert(repo_url.clone()) {
231 let repo_path = example_instance.repo_path.clone();
232
233 if !repo_path.join(".git").is_dir() {
234 println!(
235 "{:<width$} < {}",
236 "↓ Cloning",
237 repo_url,
238 width = max_name_width
239 );
240
241 let git_task = cx.spawn(async move |_cx| {
242 std::fs::create_dir_all(&repo_path)?;
243 run_git(&repo_path, &["init"]).await?;
244 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
245 });
246
247 clone_tasks.push(git_task);
248 } else {
249 println!(
250 "{:<width$} < {}",
251 "✔︎ Already cloned",
252 repo_url,
253 width = max_name_width
254 );
255
256 let actual_origin =
257 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
258 if actual_origin != repo_url {
259 return Err(anyhow!(
260 "remote origin {} does not match expected origin {}",
261 actual_origin,
262 repo_url,
263 ));
264 }
265 }
266 }
267 }
268
269 future::join_all(clone_tasks).await;
270
271 for example_instance in examples.iter_mut() {
272 example_instance.fetch().await?;
273 }
274
275 let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
276 let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
277
278 future::join_all((0..args.concurrency).map(|_| {
279 let app_state = app_state.clone();
280 let model = model.clone();
281 let zed_commit_sha = zed_commit_sha.clone();
282 let zed_branch_name = zed_branch_name.clone();
283 let run_id = run_id.clone();
284 let examples = examples.clone();
285 let results = results_by_example_name.clone();
286 cx.spawn(async move |cx| {
287 loop {
288 let Some(mut example) = examples.borrow_mut().pop_front() else {
289 break;
290 };
291 let result = async {
292 example.setup().await?;
293 let run_output = cx
294 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
295 .await?;
296 let judge_output = judge_example(
297 example.clone(),
298 model.clone(),
299 &zed_commit_sha,
300 &zed_branch_name,
301 &run_id,
302 &run_output,
303 enable_telemetry,
304 cx,
305 )
306 .await;
307 anyhow::Ok((run_output, judge_output))
308 }
309 .await;
310 results
311 .borrow_mut()
312 .entry(example.name.clone())
313 .or_insert(Vec::new())
314 .push((example.clone(), result));
315 }
316 })
317 }))
318 .await;
319
320 print_report(
321 &mut results_by_example_name.borrow_mut(),
322 &mut cumulative_tool_metrics,
323 &run_dir,
324 )?;
325
326 app_state.client.telemetry().flush_events().await;
327
328 cx.update(|cx| cx.quit())
329 })
330 .detach_and_log_err(cx);
331 });
332}
333
334/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
335pub struct AgentAppState {
336 pub languages: Arc<LanguageRegistry>,
337 pub client: Arc<Client>,
338 pub user_store: Entity<UserStore>,
339 pub fs: Arc<dyn fs::Fs>,
340 pub node_runtime: NodeRuntime,
341
342 // Additional fields not present in `workspace::AppState`.
343 pub prompt_builder: Arc<PromptBuilder>,
344}
345
/// Initializes the global state a headless eval run needs and returns the shared handles.
///
/// In order, this sets up:
/// - the release channel and the Tokio integration;
/// - the settings store, seeded with the default settings;
/// - an HTTP client that honors the configured (or environment) proxy;
/// - project settings and the production `Client`;
/// - a real filesystem, the language registry and the user store;
/// - extensions and a Node runtime;
/// - the language, language-model, prompt, agent and tool subsystems.
///
/// Finally, `runner_settings.json` is applied as the user settings.
///
/// # Panics
/// Panics if the default or runner settings fail to load, or if the HTTP client cannot be
/// constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings. Fall back to the environment when the setting is absent
    // or does not parse.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // NOTE(review): presumably the reqwest client must be built inside a Tokio runtime
        // context, hence entering the Tokio handle for the duration of this block. Confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // Replace the app-wide HTTP client installed above with the one owned by `client`.
    cx.set_http_client(client.http_client());

    // `None`: no explicit git binary path is configured for the filesystem.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the global settings change, recompute the Node options from the project
    // settings and publish them on a watch channel. The receiver is handed to
    // `NodeRuntime::new` below.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version,
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    npm_path.unwrap_or_else(|| {
                        // Without an explicit npm path, default to an `npm` next to the node
                        // binary (or a bare relative `npm` if the node path has no parent).
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client(), None, rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    debug_adapter_extension::init(extension_host_proxy.clone(), cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    prompt_store::init(cx);
    // Hard-coded to `false` for this headless runner; forwarded to `PromptBuilder::load`.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(
        fs.clone(),
        client.clone(),
        prompt_builder.clone(),
        languages.clone(),
        cx,
    );
    assistant_tools::init(client.http_client(), cx);

    // Apply the eval runner's settings on top of the defaults, as user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
456
457pub fn find_model(
458 provider_id: &str,
459 model_id: &str,
460 model_registry: &LanguageModelRegistry,
461 cx: &App,
462) -> anyhow::Result<Arc<dyn LanguageModel>> {
463 let matching_models = model_registry
464 .available_models(cx)
465 .filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id)
466 .collect::<Vec<_>>();
467
468 match matching_models.as_slice() {
469 [model] => Ok(model.clone()),
470 [] => Err(anyhow!(
471 "No language model with ID {}/{} was available. Available models: {}",
472 provider_id,
473 model_id,
474 model_registry
475 .available_models(cx)
476 .map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
477 .collect::<Vec<_>>()
478 .join(", ")
479 )),
480 _ => Err(anyhow!(
481 "Multiple language models with ID {} available - use `--provider` to choose one of: {:?}",
482 model_id,
483 matching_models
484 .iter()
485 .map(|model| model.provider_id().0)
486 .collect::<Vec<_>>()
487 )),
488 }
489}
490
491pub fn commit_sha_for_path(repo_path: &Path) -> String {
492 futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
493}
494
495pub fn git_branch_for_path(repo_path: &Path) -> String {
496 match std::env::var("GITHUB_REF_NAME") {
497 Ok(branch) => branch,
498 Err(_) => {
499 futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
500 .unwrap_or_else(|_| "unknown".to_string())
501 }
502 }
503}
504
505async fn judge_example(
506 example: ExampleInstance,
507 model: Arc<dyn LanguageModel>,
508 zed_commit_sha: &str,
509 zed_branch_name: &str,
510 run_id: &str,
511 run_output: &RunOutput,
512 enable_telemetry: bool,
513 cx: &AsyncApp,
514) -> JudgeOutput {
515 let judge_output = example.judge(model.clone(), &run_output, cx).await;
516
517 if enable_telemetry {
518 telemetry::event!(
519 "Agent Example Evaluated",
520 zed_commit_sha = zed_commit_sha,
521 zed_branch_name = zed_branch_name,
522 run_id = run_id,
523 example_name = example.name.clone(),
524 example_repetition = example.repetition,
525 diff_evaluation = judge_output.diff.clone(),
526 thread_evaluation = judge_output.thread.clone(),
527 tool_metrics = run_output.tool_metrics,
528 response_count = run_output.response_count,
529 token_usage = run_output.token_usage,
530 model = model.telemetry_id(),
531 model_provider = model.provider_id().to_string(),
532 repository_url = example.repo_url(),
533 repository_revision = example.revision(),
534 diagnostic_summary_before = run_output.diagnostic_summary_before,
535 diagnostic_summary_after = run_output.diagnostic_summary_after,
536 diagnostics_before = run_output.diagnostics_before,
537 diagnostics_after = run_output.diagnostics_after,
538 );
539 }
540
541 judge_output
542}
543
/// Width, in characters, of the rules and centered titles in the report.
const HEADER_WIDTH: usize = 65;

/// Prints a top-level report heading framed by `=` rules, preceded by two blank lines.
fn print_h1(header: &str) {
    print_banner(header, '=', "\n\n");
}

/// Prints a section heading framed by `-` rules, preceded by one blank line.
fn print_h2(header: &str) {
    print_banner(header, '-', "\n");
}

/// Prints `header` centered between two `HEADER_WIDTH`-wide rules of `rule_char`.
/// `gap_above` is printed before the first rule, and one blank line follows the second.
fn print_banner(header: &str, rule_char: char, gap_above: &str) {
    let rule = rule_char.to_string().repeat(HEADER_WIDTH);
    println!("{gap_above}{rule}");
    println!("{header:^HEADER_WIDTH$}");
    println!("{rule}\n");
}
557
558fn print_report(
559 results_by_example_name: &mut HashMap<
560 String,
561 Vec<(ExampleInstance, anyhow::Result<(RunOutput, JudgeOutput)>)>,
562 >,
563 cumulative_tool_metrics: &mut ToolMetrics,
564 run_dir: &Path,
565) -> anyhow::Result<()> {
566 print_h1("EVAL RESULTS");
567
568 let mut diff_scores = Vec::new();
569 let mut thread_scores = Vec::new();
570 let mut programmatic_scores = Vec::new();
571 let mut error_count = 0;
572
573 for (example_name, results) in results_by_example_name.iter_mut() {
574 print_h2(example_name);
575
576 results.sort_unstable_by_key(|(example, _)| example.repetition);
577 let mut example_cumulative_tool_metrics = ToolMetrics::default();
578
579 let mut table_rows = String::new();
580
581 for (example, result) in results.iter() {
582 match result {
583 Err(err) => {
584 display_error_row(&mut table_rows, example.repetition, err.to_string())?;
585 error_count += 1;
586 programmatic_scores.push(0.0);
587 diff_scores.push(0.0);
588 thread_scores.push(0.0);
589 }
590 Ok((run_output, judge_output)) => {
591 cumulative_tool_metrics.merge(&run_output.tool_metrics);
592 example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
593
594 if run_output.programmatic_assertions.total_count() > 0 {
595 for assertion in &run_output.programmatic_assertions.ran {
596 assertions::display_table_row(
597 &mut table_rows,
598 example.repetition,
599 assertion,
600 )?;
601 }
602
603 programmatic_scores
604 .push(run_output.programmatic_assertions.passed_percentage())
605 }
606
607 if !judge_output.diff.is_empty() {
608 diff_scores.push(judge_output.diff.passed_percentage());
609
610 for assertion in &judge_output.diff.ran {
611 assertions::display_table_row(
612 &mut table_rows,
613 example.repetition,
614 assertion,
615 )?;
616 }
617 }
618
619 if !judge_output.thread.is_empty() {
620 thread_scores.push(judge_output.thread.passed_percentage());
621
622 for assertion in &judge_output.thread.ran {
623 assertions::display_table_row(
624 &mut table_rows,
625 example.repetition,
626 assertion,
627 )?;
628 }
629 }
630 }
631 }
632 }
633
634 let mut all_asserts = Vec::new();
635
636 if !table_rows.is_empty() {
637 assertions::print_table_header();
638 print!("{}", table_rows);
639
640 assertions::print_table_divider();
641
642 for (example, result) in results.iter() {
643 if let Ok((run_output, judge_output)) = result {
644 let asserts = [
645 run_output.programmatic_assertions.clone(),
646 judge_output.diff.clone(),
647 judge_output.thread.clone(),
648 ];
649 all_asserts.extend_from_slice(&asserts);
650 assertions::print_table_round_summary(
651 &example.repetition.to_string(),
652 asserts.iter(),
653 )
654 } else if let Err(err) = result {
655 let assert = AssertionsReport::error(err.to_string());
656 all_asserts.push(assert.clone());
657 assertions::print_table_round_summary(
658 &example.repetition.to_string(),
659 [assert].iter(),
660 )
661 }
662 }
663
664 assertions::print_table_divider();
665
666 assertions::print_table_round_summary("avg", all_asserts.iter());
667
668 assertions::print_table_footer();
669 }
670
671 if !example_cumulative_tool_metrics.is_empty() {
672 println!("{}", &example_cumulative_tool_metrics);
673 }
674 }
675
676 if results_by_example_name.len() > 1 {
677 print_h1("AGGREGATE");
678
679 if error_count > 0 {
680 println!("\n{error_count} examples failed to run!");
681 }
682
683 let programmatic_score_count = programmatic_scores.len();
684 if programmatic_score_count > 0 {
685 let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
686 / (programmatic_score_count as f32))
687 .floor();
688 println!("Average programmatic score: {average_programmatic_score}%");
689 }
690
691 let diff_score_count = diff_scores.len();
692 if diff_score_count > 0 {
693 let average_diff_score =
694 (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
695 println!("Average diff score: {average_diff_score}%");
696 }
697
698 let thread_score_count = thread_scores.len();
699
700 if thread_score_count > 0 {
701 let average_thread_score =
702 (thread_scores.into_iter().sum::<f32>() / (thread_score_count as f32)).floor();
703 println!("Average thread score: {average_thread_score}%");
704 }
705
706 println!("");
707
708 print_h2("CUMULATIVE TOOL METRICS");
709 println!("{}", cumulative_tool_metrics);
710 }
711
712 let explorer_output_path = run_dir.join("overview.html");
713 let mut json_paths: Vec<PathBuf> = results_by_example_name
714 .values()
715 .flat_map(|results| {
716 results.iter().map(|(example, _)| {
717 let absolute_path = run_dir.join(example.run_directory.join("last.messages.json"));
718 let cwd = std::env::current_dir().expect("Can't get current dir");
719 pathdiff::diff_paths(&absolute_path, cwd).unwrap_or_else(|| absolute_path.clone())
720 })
721 })
722 .collect::<Vec<_>>();
723 json_paths.sort();
724 if let Err(err) = explorer::generate_explorer_html(&json_paths, &explorer_output_path) {
725 eprintln!("Failed to generate explorer HTML: {}", err);
726 }
727
728 Ok(())
729}