1mod assertions;
2mod example;
3mod examples;
4mod explorer;
5mod ids;
6mod instance;
7mod tool_metrics;
8
9use assertions::display_error_row;
10use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
11pub(crate) use tool_metrics::*;
12
13use ::fs::RealFs;
14use anyhow::anyhow;
15use clap::Parser;
16use client::{Client, ProxySettings, UserStore};
17use collections::{HashMap, HashSet};
18use extension::ExtensionHostProxy;
19use futures::future;
20use gpui::http_client::read_proxy_from_env;
21use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
22use gpui_tokio::Tokio;
23use language::LanguageRegistry;
24use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
25use node_runtime::{NodeBinaryOptions, NodeRuntime};
26use project::Project;
27use project::project_settings::ProjectSettings;
28use prompt_store::PromptBuilder;
29use release_channel::AppVersion;
30use reqwest_client::ReqwestClient;
31use settings::{Settings, SettingsStore};
32use std::cell::RefCell;
33use std::collections::VecDeque;
34use std::env;
35use std::path::{Path, PathBuf};
36use std::rc::Rc;
37use std::sync::{Arc, LazyLock};
38use util::ResultExt as _;
39
40static CARGO_MANIFEST_DIR: LazyLock<PathBuf> =
41 LazyLock::new(|| PathBuf::from(env!("CARGO_MANIFEST_DIR")));
42
43#[derive(Parser, Debug)]
44#[command(name = "eval", disable_version_flag = true)]
45struct Args {
46 /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
47 #[arg(value_name = "EXAMPLE_SUBSTRING")]
48 filter: Vec<String>,
49 /// Model to use (default: "claude-3-7-sonnet-latest")
50 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
51 model: String,
52 #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
53 languages: Vec<String>,
54 /// How many times to run each example.
55 #[arg(long, default_value = "8")]
56 repetitions: usize,
57 /// Maximum number of examples to run concurrently.
58 #[arg(long, default_value = "4")]
59 concurrency: usize,
60}
61
62fn main() {
63 dotenv::from_filename(CARGO_MANIFEST_DIR.join(".env")).ok();
64
65 env_logger::init();
66
67 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
68 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
69 let session_id = uuid::Uuid::new_v4().to_string();
70 let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
71 let run_id = match env::var("GITHUB_RUN_ID") {
72 Ok(run_id) => format!("github/{}", run_id),
73 Err(_) => format!("local/{}", run_timestamp),
74 };
75
76 let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
77 .parent()
78 .unwrap()
79 .parent()
80 .unwrap()
81 .canonicalize()
82 .unwrap();
83 let eval_crate_dir = root_dir.join("crates").join("eval");
84 let repos_dir = eval_crate_dir.join("repos");
85 let worktrees_dir = eval_crate_dir.join("worktrees");
86 let examples_dir = eval_crate_dir.join("src").join("examples");
87 let run_dir = eval_crate_dir
88 .join("runs")
89 .join(format!("{}", run_timestamp));
90 std::fs::create_dir_all(&run_dir).unwrap();
91 std::fs::create_dir_all(&repos_dir).unwrap();
92 std::fs::create_dir_all(&worktrees_dir).unwrap();
93 std::fs::create_dir_all(&examples_dir).unwrap();
94 std::fs::create_dir_all(&paths::config_dir()).unwrap();
95
96 let zed_commit_sha = commit_sha_for_path(&root_dir);
97 let zed_branch_name = git_branch_for_path(&root_dir);
98 let args = Args::parse();
99 let languages: HashSet<String> = args.languages.into_iter().collect();
100
101 let http_client = Arc::new(ReqwestClient::new());
102 let app = Application::headless().with_http_client(http_client.clone());
103 let all_threads = examples::all(&examples_dir);
104
105 app.run(move |cx| {
106 let app_state = init(cx);
107
108 let telemetry = app_state.client.telemetry();
109 telemetry.start(system_id, installation_id, session_id, cx);
110
111 let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
112 && telemetry.has_checksum_seed();
113 if enable_telemetry {
114 println!("Telemetry enabled");
115 telemetry::event!(
116 "Agent Eval Started",
117 zed_commit_sha = zed_commit_sha,
118 zed_branch_name = zed_branch_name,
119 run_id = run_id,
120 );
121 }
122
123 let mut cumulative_tool_metrics = ToolMetrics::default();
124
125 let model_registry = LanguageModelRegistry::read_global(cx);
126 let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
127 let model_provider_id = model.provider_id();
128 let model_provider = model_registry.provider(&model_provider_id).unwrap();
129
130 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
131 registry.set_default_model(
132 Some(ConfiguredModel {
133 provider: model_provider.clone(),
134 model: model.clone(),
135 }),
136 cx,
137 );
138 });
139
140 let authenticate_task = model_provider.authenticate(cx);
141
142 cx.spawn(async move |cx| {
143 authenticate_task.await.unwrap();
144
145 let mut examples = Vec::new();
146
147 const COLORS: [&str; 12] = [
148 "\x1b[31m", // Red
149 "\x1b[32m", // Green
150 "\x1b[33m", // Yellow
151 "\x1b[34m", // Blue
152 "\x1b[35m", // Magenta
153 "\x1b[36m", // Cyan
154 "\x1b[91m", // Bright Red
155 "\x1b[92m", // Bright Green
156 "\x1b[93m", // Bright Yellow
157 "\x1b[94m", // Bright Blue
158 "\x1b[95m", // Bright Magenta
159 "\x1b[96m", // Bright Cyan
160 ];
161
162 let mut skipped = Vec::new();
163
164 for thread in all_threads {
165 let meta = thread.meta();
166 if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
167 {
168 skipped.push(meta.name);
169 continue;
170 }
171
172 if let Some(language) = meta.language_server {
173 if !languages.contains(&language.file_extension) {
174 panic!(
175 "Eval for {:?} could not be run because no language server was found for extension {:?}",
176 meta.name,
177 language.file_extension
178 );
179 }
180 }
181
182 // TODO: This creates a worktree per repetition. Ideally these examples should
183 // either be run sequentially on the same worktree, or reuse worktrees when there
184 // are more examples to run than the concurrency limit.
185 for repetition_number in 0..args.repetitions {
186 let example_instance = ExampleInstance::new(
187 thread.clone(),
188 &repos_dir,
189 &run_dir,
190 &worktrees_dir,
191 repetition_number,
192 );
193
194 examples.push(example_instance);
195 }
196 }
197
198 if !skipped.is_empty() {
199 println!("Skipped threads: {}", skipped.join(", "));
200 }
201
202 if examples.is_empty() {
203 eprintln!("Filter matched no examples");
204 return cx.update(|cx| cx.quit());
205 }
206
207 let mut repo_urls = HashSet::default();
208 let mut clone_tasks = Vec::new();
209
210 let max_name_width = examples
211 .iter()
212 .map(|e| e.worktree_name().len())
213 .max()
214 .unwrap_or(0);
215
216 for (i, example_instance) in examples.iter_mut().enumerate() {
217 let color = COLORS[i % COLORS.len()].to_string();
218 example_instance.set_log_prefix_style(&color, max_name_width);
219
220 println!(
221 "{}Logging to: {}",
222 example_instance.log_prefix,
223 example_instance.run_directory.display()
224 );
225
226 let repo_url = example_instance.repo_url();
227 if repo_urls.insert(repo_url.clone()) {
228 let repo_path = example_instance.repo_path.clone();
229
230 if !repo_path.join(".git").is_dir() {
231 println!(
232 "{:<width$} < {}",
233 "↓ Cloning",
234 repo_url,
235 width = max_name_width
236 );
237
238 let git_task = cx.spawn(async move |_cx| {
239 std::fs::create_dir_all(&repo_path)?;
240 run_git(&repo_path, &["init"]).await?;
241 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
242 });
243
244 clone_tasks.push(git_task);
245 } else {
246 println!(
247 "{:<width$} < {}",
248 "✔︎ Already cloned",
249 repo_url,
250 width = max_name_width
251 );
252
253 let actual_origin =
254 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
255 if actual_origin != repo_url {
256 return Err(anyhow!(
257 "remote origin {} does not match expected origin {}",
258 actual_origin,
259 repo_url,
260 ));
261 }
262 }
263 }
264 }
265
266 future::join_all(clone_tasks).await;
267
268 for example_instance in examples.iter_mut() {
269 example_instance.fetch().await?;
270 }
271
272 let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
273 let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
274
275 future::join_all((0..args.concurrency).map(|_| {
276 let app_state = app_state.clone();
277 let model = model.clone();
278 let zed_commit_sha = zed_commit_sha.clone();
279 let zed_branch_name = zed_branch_name.clone();
280 let run_id = run_id.clone();
281 let examples = examples.clone();
282 let results = results_by_example_name.clone();
283 cx.spawn(async move |cx| {
284 loop {
285 let Some(mut example) = examples.borrow_mut().pop_front() else {
286 break;
287 };
288 let result = async {
289 example.setup().await?;
290 let run_output = cx
291 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
292 .await?;
293 let judge_output = judge_example(
294 example.clone(),
295 model.clone(),
296 &zed_commit_sha,
297 &zed_branch_name,
298 &run_id,
299 &run_output,
300 enable_telemetry,
301 cx,
302 )
303 .await;
304 anyhow::Ok((run_output, judge_output))
305 }
306 .await;
307 results
308 .borrow_mut()
309 .entry(example.name.clone())
310 .or_insert(Vec::new())
311 .push((example.clone(), result));
312 }
313 })
314 }))
315 .await;
316
317 print_report(
318 &mut results_by_example_name.borrow_mut(),
319 &mut cumulative_tool_metrics,
320 &run_dir,
321 )?;
322
323 app_state.client.telemetry().flush_events().await;
324
325 cx.update(|cx| cx.quit())
326 })
327 .detach_and_log_err(cx);
328 });
329}
330
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// A single instance is built by [`init`] and handed out behind an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `init` via `Client::production`; also the source of the
    /// telemetry handle and of the HTTP client installed on the app.
    pub client: Arc<Client>,
    /// User store entity constructed from `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem implementation; `init` supplies a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime that receives its `NodeBinaryOptions` from project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init` through `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
342
/// Initializes the globals and subsystems the headless eval runner depends on and
/// returns the shared [`AgentAppState`].
///
/// The statements below register globals on `cx` in a specific order (settings before
/// anything that reads settings, the HTTP client before the `Client`, and so on), so the
/// order should be preserved when editing.
///
/// # Panics
///
/// Panics if the default settings or the bundled `runner_settings.json` cannot be
/// applied, or if the HTTP client cannot be started.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // The release channel is initialized with `SemanticVersion::default()` as the version.
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install a settings store seeded with the built-in default settings.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings; fall back to the environment when the setting is
    // absent or does not parse.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // The Tokio runtime is entered only for the duration of client construction.
        // NOTE(review): presumably required by the reqwest-based client — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // From here on the app-wide HTTP client is the one owned by `client`.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Node options are recomputed from `ProjectSettings` inside a `SettingsStore`
    // observer and published to the node runtime through a watch channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(
        fs.clone(),
        client.clone(),
        prompt_builder.clone(),
        languages.clone(),
        cx,
    );
    assistant_tools::init(client.http_client(), cx);

    // Apply the settings file bundled with the runner as the user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
453
454pub fn find_model(
455 model_name: &str,
456 model_registry: &LanguageModelRegistry,
457 cx: &App,
458) -> anyhow::Result<Arc<dyn LanguageModel>> {
459 let model = model_registry
460 .available_models(cx)
461 .find(|model| model.id().0 == model_name);
462
463 let Some(model) = model else {
464 return Err(anyhow!(
465 "No language model named {} was available. Available models: {}",
466 model_name,
467 model_registry
468 .available_models(cx)
469 .map(|model| model.id().0.clone())
470 .collect::<Vec<_>>()
471 .join(", ")
472 ));
473 };
474
475 Ok(model)
476}
477
478pub fn commit_sha_for_path(repo_path: &Path) -> String {
479 futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
480}
481
482pub fn git_branch_for_path(repo_path: &Path) -> String {
483 match std::env::var("GITHUB_REF_NAME") {
484 Ok(branch) => branch,
485 Err(_) => {
486 futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
487 .unwrap_or_else(|_| "unknown".to_string())
488 }
489 }
490}
491
492async fn judge_example(
493 example: ExampleInstance,
494 model: Arc<dyn LanguageModel>,
495 zed_commit_sha: &str,
496 zed_branch_name: &str,
497 run_id: &str,
498 run_output: &RunOutput,
499 enable_telemetry: bool,
500 cx: &AsyncApp,
501) -> JudgeOutput {
502 let judge_output = example.judge(model.clone(), &run_output, cx).await;
503
504 if enable_telemetry {
505 telemetry::event!(
506 "Agent Example Evaluated",
507 zed_commit_sha = zed_commit_sha,
508 zed_branch_name = zed_branch_name,
509 run_id = run_id,
510 example_name = example.name.clone(),
511 example_repetition = example.repetition,
512 diff_evaluation = judge_output.diff.clone(),
513 thread_evaluation = judge_output.thread.clone(),
514 tool_metrics = run_output.tool_metrics,
515 response_count = run_output.response_count,
516 token_usage = run_output.token_usage,
517 model = model.telemetry_id(),
518 model_provider = model.provider_id().to_string(),
519 repository_url = example.repo_url(),
520 repository_revision = example.revision(),
521 diagnostic_summary_before = run_output.diagnostic_summary_before,
522 diagnostic_summary_after = run_output.diagnostic_summary_after,
523 diagnostics_before = run_output.diagnostics_before,
524 diagnostics_after = run_output.diagnostics_after,
525 );
526 }
527
528 judge_output
529}
530
/// Width, in characters, of the banner lines printed by `print_h1` and `print_h2`.
const HEADER_WIDTH: usize = 65;
532
533fn print_h1(header: &str) {
534 println!("\n\n{:=^HEADER_WIDTH$}", "");
535 println!("{:^HEADER_WIDTH$}", header);
536 println!("{:=^HEADER_WIDTH$}\n", "");
537}
538
539fn print_h2(header: &str) {
540 println!("\n{:-^HEADER_WIDTH$}", "");
541 println!("{:^HEADER_WIDTH$}", header);
542 println!("{:-^HEADER_WIDTH$}\n", "");
543}
544
545fn print_report(
546 results_by_example_name: &mut HashMap<
547 String,
548 Vec<(ExampleInstance, anyhow::Result<(RunOutput, JudgeOutput)>)>,
549 >,
550 cumulative_tool_metrics: &mut ToolMetrics,
551 run_dir: &Path,
552) -> anyhow::Result<()> {
553 print_h1("EVAL RESULTS");
554
555 let mut diff_scores = Vec::new();
556 let mut thread_scores = Vec::new();
557 let mut programmatic_scores = Vec::new();
558 let mut error_count = 0;
559
560 for (example_name, results) in results_by_example_name.iter_mut() {
561 print_h2(example_name);
562
563 results.sort_unstable_by_key(|(example, _)| example.repetition);
564 let mut example_cumulative_tool_metrics = ToolMetrics::default();
565
566 let mut table_rows = String::new();
567
568 for (example, result) in results.iter() {
569 match result {
570 Err(err) => {
571 display_error_row(&mut table_rows, example.repetition, err.to_string())?;
572 error_count += 1;
573 }
574 Ok((run_output, judge_output)) => {
575 cumulative_tool_metrics.merge(&run_output.tool_metrics);
576 example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
577
578 if !run_output.programmatic_assertions.total_count() > 0 {
579 for assertion in &run_output.programmatic_assertions.ran {
580 assertions::display_table_row(
581 &mut table_rows,
582 example.repetition,
583 assertion,
584 )?;
585 }
586
587 programmatic_scores
588 .push(run_output.programmatic_assertions.passed_percentage())
589 }
590
591 if !judge_output.diff.is_empty() {
592 diff_scores.push(judge_output.diff.passed_percentage());
593
594 for assertion in &judge_output.diff.ran {
595 assertions::display_table_row(
596 &mut table_rows,
597 example.repetition,
598 assertion,
599 )?;
600 }
601 }
602
603 if !judge_output.thread.is_empty() {
604 thread_scores.push(judge_output.thread.passed_percentage());
605
606 for assertion in &judge_output.thread.ran {
607 assertions::display_table_row(
608 &mut table_rows,
609 example.repetition,
610 assertion,
611 )?;
612 }
613 }
614 }
615 }
616 }
617
618 if !table_rows.is_empty() {
619 assertions::print_table_header();
620 print!("{}", table_rows);
621
622 assertions::print_table_divider();
623
624 for (example, result) in results.iter() {
625 if let Ok((run_output, judge_output)) = result {
626 assertions::print_table_round_summary(
627 &example.repetition.to_string(),
628 [
629 &run_output.programmatic_assertions,
630 &judge_output.diff,
631 &judge_output.thread,
632 ]
633 .into_iter(),
634 )
635 }
636 }
637
638 assertions::print_table_divider();
639
640 assertions::print_table_round_summary(
641 "avg",
642 results.iter().flat_map(|(_, result)| {
643 result.iter().flat_map(|(run_output, judge_output)| {
644 [
645 &run_output.programmatic_assertions,
646 &judge_output.diff,
647 &judge_output.thread,
648 ]
649 .into_iter()
650 })
651 }),
652 );
653
654 assertions::print_table_footer();
655 }
656
657 if !example_cumulative_tool_metrics.is_empty() {
658 println!("{}", &example_cumulative_tool_metrics);
659 }
660 }
661
662 if results_by_example_name.len() > 1 {
663 print_h1("AGGREGATE");
664
665 if error_count > 0 {
666 println!("\n{error_count} examples failed to run!");
667 }
668
669 let programmatic_score_count = programmatic_scores.len();
670 if programmatic_score_count > 0 {
671 let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
672 / (programmatic_score_count as f32))
673 .floor();
674 println!("Average programmatic score: {average_programmatic_score}%");
675 }
676
677 let diff_score_count = diff_scores.len();
678 if diff_score_count > 0 {
679 let average_diff_score =
680 (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
681 println!("Average diff score: {average_diff_score}%");
682 }
683
684 let thread_score_count = thread_scores.len();
685
686 if thread_score_count > 0 {
687 let average_thread_score =
688 (thread_scores.into_iter().sum::<f32>() / (thread_score_count as f32)).floor();
689 println!("Average thread score: {average_thread_score}%");
690 }
691
692 println!("");
693
694 print_h2("CUMULATIVE TOOL METRICS");
695 println!("{}", cumulative_tool_metrics);
696 }
697
698 let explorer_output_path = run_dir.join("overview.html");
699 let mut json_paths: Vec<PathBuf> = results_by_example_name
700 .values()
701 .flat_map(|results| {
702 results.iter().map(|(example, _)| {
703 let absolute_path = example.run_directory.join("last.messages.json");
704 pathdiff::diff_paths(&absolute_path, run_dir)
705 .unwrap_or_else(|| absolute_path.clone())
706 })
707 })
708 .collect::<Vec<_>>();
709 json_paths.sort();
710 if let Err(err) = explorer::generate_explorer_html(&json_paths, &explorer_output_path) {
711 eprintln!("Failed to generate explorer HTML: {}", err);
712 }
713
714 Ok(())
715}