1mod assertions;
2mod example;
3mod examples;
4mod explorer;
5mod ids;
6mod instance;
7mod tool_metrics;
8
9use assertions::{AssertionsReport, display_error_row};
10use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
11pub(crate) use tool_metrics::*;
12
13use ::fs::RealFs;
14use clap::Parser;
15use client::{Client, ProxySettings, UserStore};
16use collections::{HashMap, HashSet};
17use extension::ExtensionHostProxy;
18use futures::future;
19use gpui::http_client::read_proxy_from_env;
20use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
21use gpui_tokio::Tokio;
22use language::LanguageRegistry;
23use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
24use node_runtime::{NodeBinaryOptions, NodeRuntime};
25use project::Project;
26use project::project_settings::ProjectSettings;
27use prompt_store::PromptBuilder;
28use release_channel::AppVersion;
29use reqwest_client::ReqwestClient;
30use settings::{Settings, SettingsStore};
31use std::cell::RefCell;
32use std::collections::VecDeque;
33use std::env;
34use std::path::{Path, PathBuf};
35use std::rc::Rc;
36use std::sync::{Arc, LazyLock};
37use util::ResultExt as _;
38
39static CARGO_MANIFEST_DIR: LazyLock<PathBuf> =
40 LazyLock::new(|| PathBuf::from(env!("CARGO_MANIFEST_DIR")));
41
42#[derive(Parser, Debug)]
43#[command(name = "eval", disable_version_flag = true)]
44struct Args {
45 /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
46 #[arg(value_name = "EXAMPLE_SUBSTRING")]
47 filter: Vec<String>,
48 /// ID of model to use.
49 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
50 model: String,
51 /// Model provider to use.
52 #[arg(long, default_value = "anthropic")]
53 provider: String,
54 #[arg(long, value_delimiter = ',', default_value = "rs,ts,py")]
55 languages: Vec<String>,
56 /// How many times to run each example.
57 #[arg(long, default_value = "8")]
58 repetitions: usize,
59 /// Maximum number of examples to run concurrently.
60 #[arg(long, default_value = "4")]
61 concurrency: usize,
62}
63
64fn main() {
65 dotenv::from_filename(CARGO_MANIFEST_DIR.join(".env")).ok();
66
67 env_logger::init();
68
69 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
70 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
71 let session_id = uuid::Uuid::new_v4().to_string();
72 let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
73 let run_id = match env::var("GITHUB_RUN_ID") {
74 Ok(run_id) => format!("github/{}", run_id),
75 Err(_) => format!("local/{}", run_timestamp),
76 };
77
78 let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
79 .parent()
80 .unwrap()
81 .parent()
82 .unwrap()
83 .canonicalize()
84 .unwrap();
85 let eval_crate_dir = root_dir.join("crates").join("eval");
86 let repos_dir = eval_crate_dir.join("repos");
87 let worktrees_dir = eval_crate_dir.join("worktrees");
88 let examples_dir = eval_crate_dir.join("src").join("examples");
89 let run_dir = eval_crate_dir
90 .join("runs")
91 .join(format!("{}", run_timestamp));
92 std::fs::create_dir_all(&run_dir).unwrap();
93 std::fs::create_dir_all(&repos_dir).unwrap();
94 std::fs::create_dir_all(&worktrees_dir).unwrap();
95 std::fs::create_dir_all(&examples_dir).unwrap();
96 std::fs::create_dir_all(&paths::config_dir()).unwrap();
97
98 let zed_commit_sha = commit_sha_for_path(&root_dir);
99 let zed_branch_name = git_branch_for_path(&root_dir);
100 let args = Args::parse();
101 let languages: HashSet<String> = args.languages.into_iter().collect();
102
103 let http_client = Arc::new(ReqwestClient::new());
104 let app = Application::headless().with_http_client(http_client.clone());
105 let all_threads = examples::all(&examples_dir);
106
107 app.run(move |cx| {
108 let app_state = init(cx);
109
110 let telemetry = app_state.client.telemetry();
111 telemetry.start(system_id, installation_id, session_id, cx);
112
113 let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
114 && telemetry.has_checksum_seed();
115 if enable_telemetry {
116 println!("Telemetry enabled");
117 telemetry::event!(
118 "Agent Eval Started",
119 zed_commit_sha = zed_commit_sha,
120 zed_branch_name = zed_branch_name,
121 run_id = run_id,
122 );
123 }
124
125 let mut cumulative_tool_metrics = ToolMetrics::default();
126
127 let model_registry = LanguageModelRegistry::read_global(cx);
128 let model = find_model(&args.provider, &args.model, model_registry, cx).unwrap();
129 let model_provider_id = model.provider_id();
130 let model_provider = model_registry.provider(&model_provider_id).unwrap();
131
132 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
133 registry.set_default_model(
134 Some(ConfiguredModel {
135 provider: model_provider.clone(),
136 model: model.clone(),
137 }),
138 cx,
139 );
140 });
141
142 let authenticate_task = model_provider.authenticate(cx);
143
144 cx.spawn(async move |cx| {
145 authenticate_task.await.unwrap();
146
147 let mut examples = Vec::new();
148
149 const COLORS: [&str; 12] = [
150 "\x1b[31m", // Red
151 "\x1b[32m", // Green
152 "\x1b[33m", // Yellow
153 "\x1b[34m", // Blue
154 "\x1b[35m", // Magenta
155 "\x1b[36m", // Cyan
156 "\x1b[91m", // Bright Red
157 "\x1b[92m", // Bright Green
158 "\x1b[93m", // Bright Yellow
159 "\x1b[94m", // Bright Blue
160 "\x1b[95m", // Bright Magenta
161 "\x1b[96m", // Bright Cyan
162 ];
163
164 let mut skipped = Vec::new();
165
166 for thread in all_threads {
167 let meta = thread.meta();
168 if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
169 {
170 skipped.push(meta.name);
171 continue;
172 }
173
174 if let Some(language) = meta.language_server {
175 if !languages.contains(&language.file_extension) {
176 panic!(
177 "Eval for {:?} could not be run because no language server was found for extension {:?}",
178 meta.name,
179 language.file_extension
180 );
181 }
182 }
183
184 // TODO: This creates a worktree per repetition. Ideally these examples should
185 // either be run sequentially on the same worktree, or reuse worktrees when there
186 // are more examples to run than the concurrency limit.
187 for repetition_number in 0..args.repetitions {
188 let example_instance = ExampleInstance::new(
189 thread.clone(),
190 &repos_dir,
191 &run_dir,
192 &worktrees_dir,
193 repetition_number,
194 );
195
196 examples.push(example_instance);
197 }
198 }
199
200 if !skipped.is_empty() {
201 println!("Skipped threads: {}", skipped.join(", "));
202 }
203
204 if examples.is_empty() {
205 eprintln!("Filter matched no examples");
206 return cx.update(|cx| cx.quit());
207 }
208
209 let mut repo_urls = HashSet::default();
210 let mut clone_tasks = Vec::new();
211
212 let max_name_width = examples
213 .iter()
214 .map(|e| e.worktree_name().len())
215 .max()
216 .unwrap_or(0);
217
218 for (i, example_instance) in examples.iter_mut().enumerate() {
219 let color = COLORS[i % COLORS.len()].to_string();
220 example_instance.set_log_prefix_style(&color, max_name_width);
221
222 println!(
223 "{}Logging to: {}",
224 example_instance.log_prefix,
225 example_instance.run_directory.display()
226 );
227
228 let repo_url = example_instance.repo_url();
229 if repo_urls.insert(repo_url.clone()) {
230 let repo_path = example_instance.repo_path.clone();
231
232 if !repo_path.join(".git").is_dir() {
233 println!(
234 "{:<width$} < {}",
235 "↓ Cloning",
236 repo_url,
237 width = max_name_width
238 );
239
240 let git_task = cx.spawn(async move |_cx| {
241 std::fs::create_dir_all(&repo_path)?;
242 run_git(&repo_path, &["init"]).await?;
243 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
244 });
245
246 clone_tasks.push(git_task);
247 } else {
248 println!(
249 "{:<width$} < {}",
250 "✔︎ Already cloned",
251 repo_url,
252 width = max_name_width
253 );
254
255 let actual_origin =
256 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
257 anyhow::ensure!(
258 actual_origin == repo_url,
259 "remote origin {actual_origin} does not match expected origin {repo_url}"
260 );
261 }
262 }
263 }
264
265 future::join_all(clone_tasks).await;
266
267 for example_instance in examples.iter_mut() {
268 example_instance.fetch().await?;
269 }
270
271 let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
272 let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
273
274 future::join_all((0..args.concurrency).map(|_| {
275 let app_state = app_state.clone();
276 let model = model.clone();
277 let zed_commit_sha = zed_commit_sha.clone();
278 let zed_branch_name = zed_branch_name.clone();
279 let run_id = run_id.clone();
280 let examples = examples.clone();
281 let results = results_by_example_name.clone();
282 cx.spawn(async move |cx| {
283 loop {
284 let Some(mut example) = examples.borrow_mut().pop_front() else {
285 break;
286 };
287 let result = async {
288 example.setup().await?;
289 let run_output = cx
290 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
291 .await?;
292 let judge_output = judge_example(
293 example.clone(),
294 model.clone(),
295 &zed_commit_sha,
296 &zed_branch_name,
297 &run_id,
298 &run_output,
299 enable_telemetry,
300 cx,
301 )
302 .await;
303 anyhow::Ok((run_output, judge_output))
304 }
305 .await;
306 results
307 .borrow_mut()
308 .entry(example.name.clone())
309 .or_insert(Vec::new())
310 .push((example.clone(), result));
311 }
312 })
313 }))
314 .await;
315
316 print_report(
317 &mut results_by_example_name.borrow_mut(),
318 &mut cumulative_tool_metrics,
319 &run_dir,
320 )?;
321
322 app_state.client.telemetry().flush_events().await;
323
324 cx.update(|cx| cx.quit())
325 })
326 .detach_and_log_err(cx);
327 });
328}
329
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by [`init`]; `main` clones the returned `Arc` into each eval worker and passes it
/// to every example run.
pub struct AgentAppState {
    /// Language registry created in [`init`]; language servers download into
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Production client; `main` also uses it for telemetry.
    pub client: Arc<Client>,
    /// User store backed by `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem implementation (a `RealFs` when built by [`init`]).
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options are derived from the project's `node` settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded through `fs` in [`init`].
    pub prompt_builder: Arc<PromptBuilder>,
}
341
/// Initializes the global state needed to drive the agent headlessly and returns it bundled as
/// an [`AgentAppState`].
///
/// Along the way this installs the following into `cx`:
/// - the settings store
/// - a proxy-aware HTTP client
/// - the production [`Client`]
/// - the language registry and a Node runtime
/// - language models, the agent, and the assistant tools
///
/// Finally it applies `runner_settings.json` as the user settings.
///
/// # Panics
///
/// Panics if the default settings or `runner_settings.json` cannot be applied, or if the
/// proxy-aware HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // NOTE(review): registers a default `SemanticVersion`; presumably this is what
    // `AppVersion::global` reports in the User-Agent below — confirm.
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install the settings store before anything reads settings (e.g.
    // `ProxySettings::get_global` below).
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings; fall back to the environment when the setting is absent
    // or does not parse.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the client while the Tokio runtime handle is entered (presumably the reqwest
        // client needs a runtime context during construction — confirm).
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // The app-wide HTTP client is then replaced with the one owned by `client`.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Recompute the Node binary options whenever the settings store changes and publish them to
    // the Node runtime through a watch channel that starts out as `None`.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version,
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, look for `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client(), None, rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    debug_adapter_extension::init(extension_host_proxy.clone(), cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    prompt_store::init(cx);
    terminal_view::init(cx);
    // NOTE(review): hard-coded rather than detected; presumably fine for a headless runner —
    // confirm how `PromptBuilder::load` uses it.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(
        fs.clone(),
        client.clone(),
        prompt_builder.clone(),
        languages.clone(),
        // NOTE(review): the meaning of this flag is not visible here — confirm against
        // `agent::init`.
        true,
        cx,
    );
    assistant_tools::init(client.http_client(), cx);

    // Apply the runner's settings last, as the user settings layer.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
454
455pub fn find_model(
456 provider_id: &str,
457 model_id: &str,
458 model_registry: &LanguageModelRegistry,
459 cx: &App,
460) -> anyhow::Result<Arc<dyn LanguageModel>> {
461 let matching_models = model_registry
462 .available_models(cx)
463 .filter(|model| model.id().0 == model_id && model.provider_id().0 == provider_id)
464 .collect::<Vec<_>>();
465
466 match matching_models.as_slice() {
467 [model] => Ok(model.clone()),
468 [] => anyhow::bail!(
469 "No language model with ID {}/{} was available. Available models: {}",
470 provider_id,
471 model_id,
472 model_registry
473 .available_models(cx)
474 .map(|model| format!("{}/{}", model.provider_id().0, model.id().0))
475 .collect::<Vec<_>>()
476 .join(", ")
477 ),
478 _ => anyhow::bail!(
479 "Multiple language models with ID {} available - use `--provider` to choose one of: {:?}",
480 model_id,
481 matching_models
482 .iter()
483 .map(|model| model.provider_id().0)
484 .collect::<Vec<_>>()
485 ),
486 }
487}
488
489pub fn commit_sha_for_path(repo_path: &Path) -> String {
490 futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
491}
492
493pub fn git_branch_for_path(repo_path: &Path) -> String {
494 match std::env::var("GITHUB_REF_NAME") {
495 Ok(branch) => branch,
496 Err(_) => {
497 futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
498 .unwrap_or_else(|_| "unknown".to_string())
499 }
500 }
501}
502
503async fn judge_example(
504 example: ExampleInstance,
505 model: Arc<dyn LanguageModel>,
506 zed_commit_sha: &str,
507 zed_branch_name: &str,
508 run_id: &str,
509 run_output: &RunOutput,
510 enable_telemetry: bool,
511 cx: &AsyncApp,
512) -> JudgeOutput {
513 let judge_output = example.judge(model.clone(), &run_output, cx).await;
514
515 if enable_telemetry {
516 telemetry::event!(
517 "Agent Example Evaluated",
518 zed_commit_sha = zed_commit_sha,
519 zed_branch_name = zed_branch_name,
520 run_id = run_id,
521 example_name = example.name.clone(),
522 example_repetition = example.repetition,
523 diff_evaluation = judge_output.diff.clone(),
524 thread_evaluation = judge_output.thread.clone(),
525 tool_metrics = run_output.tool_metrics,
526 response_count = run_output.response_count,
527 token_usage = run_output.token_usage,
528 model = model.telemetry_id(),
529 model_provider = model.provider_id().to_string(),
530 repository_url = example.repo_url(),
531 repository_revision = example.revision(),
532 diagnostic_summary_before = run_output.diagnostic_summary_before,
533 diagnostic_summary_after = run_output.diagnostic_summary_after,
534 diagnostics_before = run_output.diagnostics_before,
535 diagnostics_after = run_output.diagnostics_after,
536 );
537 }
538
539 judge_output
540}
541
/// Width, in columns, of the rules and centered titles printed by the report headings.
const HEADER_WIDTH: usize = 65;

/// Prints a top-level heading: the title centered between two full-width `=` rules, with two
/// newlines before the first rule and a blank line after the second.
fn print_h1(header: &str) {
    let rule = "=".repeat(HEADER_WIDTH);
    println!("\n\n{rule}");
    println!("{header:^HEADER_WIDTH$}");
    println!("{rule}\n");
}

/// Prints a section heading: the title centered between two full-width `-` rules, with one
/// newline before the first rule and a blank line after the second.
fn print_h2(header: &str) {
    let rule = "-".repeat(HEADER_WIDTH);
    println!("\n{rule}");
    println!("{header:^HEADER_WIDTH$}");
    println!("{rule}\n");
}
555
556fn print_report(
557 results_by_example_name: &mut HashMap<
558 String,
559 Vec<(ExampleInstance, anyhow::Result<(RunOutput, JudgeOutput)>)>,
560 >,
561 cumulative_tool_metrics: &mut ToolMetrics,
562 run_dir: &Path,
563) -> anyhow::Result<()> {
564 print_h1("EVAL RESULTS");
565
566 let mut diff_scores = Vec::new();
567 let mut thread_scores = Vec::new();
568 let mut programmatic_scores = Vec::new();
569 let mut error_count = 0;
570
571 for (example_name, results) in results_by_example_name.iter_mut() {
572 print_h2(example_name);
573
574 results.sort_unstable_by_key(|(example, _)| example.repetition);
575 let mut example_cumulative_tool_metrics = ToolMetrics::default();
576
577 let mut table_rows = String::new();
578
579 for (example, result) in results.iter() {
580 match result {
581 Err(err) => {
582 display_error_row(&mut table_rows, example.repetition, err.to_string())?;
583 error_count += 1;
584 programmatic_scores.push(0.0);
585 diff_scores.push(0.0);
586 thread_scores.push(0.0);
587 }
588 Ok((run_output, judge_output)) => {
589 cumulative_tool_metrics.merge(&run_output.tool_metrics);
590 example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
591
592 if run_output.programmatic_assertions.total_count() > 0 {
593 for assertion in &run_output.programmatic_assertions.ran {
594 assertions::display_table_row(
595 &mut table_rows,
596 example.repetition,
597 assertion,
598 )?;
599 }
600
601 programmatic_scores
602 .push(run_output.programmatic_assertions.passed_percentage())
603 }
604
605 if !judge_output.diff.is_empty() {
606 diff_scores.push(judge_output.diff.passed_percentage());
607
608 for assertion in &judge_output.diff.ran {
609 assertions::display_table_row(
610 &mut table_rows,
611 example.repetition,
612 assertion,
613 )?;
614 }
615 }
616
617 if !judge_output.thread.is_empty() {
618 thread_scores.push(judge_output.thread.passed_percentage());
619
620 for assertion in &judge_output.thread.ran {
621 assertions::display_table_row(
622 &mut table_rows,
623 example.repetition,
624 assertion,
625 )?;
626 }
627 }
628 }
629 }
630 }
631
632 let mut all_asserts = Vec::new();
633
634 if !table_rows.is_empty() {
635 assertions::print_table_header();
636 print!("{}", table_rows);
637
638 assertions::print_table_divider();
639
640 for (example, result) in results.iter() {
641 if let Ok((run_output, judge_output)) = result {
642 let asserts = [
643 run_output.programmatic_assertions.clone(),
644 judge_output.diff.clone(),
645 judge_output.thread.clone(),
646 ];
647 all_asserts.extend_from_slice(&asserts);
648 assertions::print_table_round_summary(
649 &example.repetition.to_string(),
650 asserts.iter(),
651 )
652 } else if let Err(err) = result {
653 let assert = AssertionsReport::error(err.to_string());
654 all_asserts.push(assert.clone());
655 assertions::print_table_round_summary(
656 &example.repetition.to_string(),
657 [assert].iter(),
658 )
659 }
660 }
661
662 assertions::print_table_divider();
663
664 assertions::print_table_round_summary("avg", all_asserts.iter());
665
666 assertions::print_table_footer();
667 }
668
669 if !example_cumulative_tool_metrics.is_empty() {
670 println!("{}", &example_cumulative_tool_metrics);
671 }
672 }
673
674 if results_by_example_name.len() > 1 {
675 print_h1("AGGREGATE");
676
677 if error_count > 0 {
678 println!("\n{error_count} examples failed to run!");
679 }
680
681 let programmatic_score_count = programmatic_scores.len();
682 if programmatic_score_count > 0 {
683 let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
684 / (programmatic_score_count as f32))
685 .floor();
686 println!("Average programmatic score: {average_programmatic_score}%");
687 }
688
689 let diff_score_count = diff_scores.len();
690 if diff_score_count > 0 {
691 let average_diff_score =
692 (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
693 println!("Average diff score: {average_diff_score}%");
694 }
695
696 let thread_score_count = thread_scores.len();
697
698 if thread_score_count > 0 {
699 let average_thread_score =
700 (thread_scores.into_iter().sum::<f32>() / (thread_score_count as f32)).floor();
701 println!("Average thread score: {average_thread_score}%");
702 }
703
704 println!("");
705
706 print_h2("CUMULATIVE TOOL METRICS");
707 println!("{}", cumulative_tool_metrics);
708 }
709
710 let explorer_output_path = run_dir.join("overview.html");
711 let mut json_paths: Vec<PathBuf> = results_by_example_name
712 .values()
713 .flat_map(|results| {
714 results.iter().map(|(example, _)| {
715 let absolute_path = run_dir.join(example.run_directory.join("last.messages.json"));
716 let cwd = std::env::current_dir().expect("Can't get current dir");
717 pathdiff::diff_paths(&absolute_path, cwd).unwrap_or_else(|| absolute_path.clone())
718 })
719 })
720 .collect::<Vec<_>>();
721 json_paths.sort();
722 if let Err(err) = explorer::generate_explorer_html(&json_paths, &explorer_output_path) {
723 eprintln!("Failed to generate explorer HTML: {}", err);
724 }
725
726 Ok(())
727}