1mod assertions;
2mod example;
3mod examples;
4mod explorer;
5mod ids;
6mod instance;
7mod tool_metrics;
8
9use assertions::display_error_row;
10use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
11pub(crate) use tool_metrics::*;
12
13use ::fs::RealFs;
14use anyhow::anyhow;
15use clap::Parser;
16use client::{Client, ProxySettings, UserStore};
17use collections::{HashMap, HashSet};
18use extension::ExtensionHostProxy;
19use futures::future;
20use gpui::http_client::{Uri, read_proxy_from_env};
21use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
22use gpui_tokio::Tokio;
23use language::LanguageRegistry;
24use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
25use node_runtime::{NodeBinaryOptions, NodeRuntime};
26use project::Project;
27use project::project_settings::ProjectSettings;
28use prompt_store::PromptBuilder;
29use release_channel::AppVersion;
30use reqwest_client::ReqwestClient;
31use settings::{Settings, SettingsStore};
32use std::cell::RefCell;
33use std::collections::VecDeque;
34use std::env;
35use std::path::{Path, PathBuf};
36use std::rc::Rc;
37use std::sync::Arc;
38use util::ResultExt as _;
39
// Command-line arguments of the eval runner, parsed with clap's derive API.
//
// NOTE(review): clap turns `///` doc comments on the struct and its fields into
// `--help` text, so the review notes added here are plain `//` comments in order
// to leave the generated help output untouched.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    filter: Vec<String>,
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    // Comma-separated list; `main` collects it into a set and skips any thread whose
    // language server's `file_extension` is not in that set.
    #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
    languages: Vec<String>,
    /// How many times to run each example.
    #[arg(long, default_value = "1")]
    repetitions: usize,
    /// Maximum number of examples to run concurrently.
    #[arg(long, default_value = "10")]
    concurrency: usize,
}
58
59fn main() {
60 env_logger::init();
61
62 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
63 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
64 let session_id = uuid::Uuid::new_v4().to_string();
65 let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
66 let run_id = match env::var("GITHUB_RUN_ID") {
67 Ok(run_id) => format!("github/{}", run_id),
68 Err(_) => format!("local/{}", run_timestamp),
69 };
70
71 let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
72 .parent()
73 .unwrap()
74 .parent()
75 .unwrap()
76 .canonicalize()
77 .unwrap();
78 let eval_crate_dir = root_dir.join("crates").join("eval");
79 let repos_dir = eval_crate_dir.join("repos");
80 let worktrees_dir = eval_crate_dir.join("worktrees");
81 let examples_dir = eval_crate_dir.join("src").join("examples");
82 let run_dir = eval_crate_dir
83 .join("runs")
84 .join(format!("{}", run_timestamp));
85 std::fs::create_dir_all(&run_dir).unwrap();
86 std::fs::create_dir_all(&repos_dir).unwrap();
87 std::fs::create_dir_all(&worktrees_dir).unwrap();
88 std::fs::create_dir_all(&examples_dir).unwrap();
89 std::fs::create_dir_all(&paths::config_dir()).unwrap();
90
91 let zed_commit_sha = commit_sha_for_path(&root_dir);
92 let zed_branch_name = git_branch_for_path(&root_dir);
93 let args = Args::parse();
94 let languages: HashSet<String> = args.languages.into_iter().collect();
95
96 let http_client = Arc::new(ReqwestClient::new());
97 let app = Application::headless().with_http_client(http_client.clone());
98 let all_threads = examples::all(&examples_dir);
99
100 app.run(move |cx| {
101 let app_state = init(cx);
102
103 let telemetry = app_state.client.telemetry();
104 telemetry.start(system_id, installation_id, session_id, cx);
105
106 let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
107 && telemetry.has_checksum_seed();
108 if enable_telemetry {
109 println!("Telemetry enabled");
110 telemetry::event!(
111 "Agent Eval Started",
112 zed_commit_sha = zed_commit_sha,
113 zed_branch_name = zed_branch_name,
114 run_id = run_id,
115 );
116 }
117
118 let mut cumulative_tool_metrics = ToolMetrics::default();
119
120 let model_registry = LanguageModelRegistry::read_global(cx);
121 let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
122 let model_provider_id = model.provider_id();
123 let model_provider = model_registry.provider(&model_provider_id).unwrap();
124
125 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
126 registry.set_default_model(
127 Some(ConfiguredModel {
128 provider: model_provider.clone(),
129 model: model.clone(),
130 }),
131 cx,
132 );
133 });
134
135 let authenticate_task = model_provider.authenticate(cx);
136
137 cx.spawn(async move |cx| {
138 authenticate_task.await.unwrap();
139
140 let mut examples = Vec::new();
141
142 const COLORS: [&str; 12] = [
143 "\x1b[31m", // Red
144 "\x1b[32m", // Green
145 "\x1b[33m", // Yellow
146 "\x1b[34m", // Blue
147 "\x1b[35m", // Magenta
148 "\x1b[36m", // Cyan
149 "\x1b[91m", // Bright Red
150 "\x1b[92m", // Bright Green
151 "\x1b[93m", // Bright Yellow
152 "\x1b[94m", // Bright Blue
153 "\x1b[95m", // Bright Magenta
154 "\x1b[96m", // Bright Cyan
155 ];
156
157 let mut skipped = Vec::new();
158
159 for thread in all_threads {
160 let meta = thread.meta();
161 if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
162 {
163 skipped.push(meta.name);
164 continue;
165 }
166
167 if meta.language_server.map_or(false, |language| {
168 !languages.contains(&language.file_extension)
169 }) {
170 skipped.push(meta.name);
171 continue;
172 }
173
174 // TODO: This creates a worktree per repetition. Ideally these examples should
175 // either be run sequentially on the same worktree, or reuse worktrees when there
176 // are more examples to run than the concurrency limit.
177 for repetition_number in 0..args.repetitions {
178 let example_instance = ExampleInstance::new(
179 thread.clone(),
180 &repos_dir,
181 &run_dir,
182 &worktrees_dir,
183 repetition_number,
184 );
185
186 examples.push(example_instance);
187 }
188 }
189
190 if !skipped.is_empty() {
191 println!("Skipped threads: {}", skipped.join(", "));
192 }
193
194 if examples.is_empty() {
195 eprintln!("Filter matched no examples");
196 return cx.update(|cx| cx.quit());
197 }
198
199 let mut repo_urls = HashSet::default();
200 let mut clone_tasks = Vec::new();
201
202 let max_name_width = examples
203 .iter()
204 .map(|e| e.worktree_name().len())
205 .max()
206 .unwrap_or(0);
207
208 for (i, example_instance) in examples.iter_mut().enumerate() {
209 let color = COLORS[i % COLORS.len()].to_string();
210 example_instance.set_log_prefix_style(&color, max_name_width);
211
212 println!(
213 "{}Logging to: {}",
214 example_instance.log_prefix,
215 example_instance.run_directory.display()
216 );
217
218 let repo_url = example_instance.repo_url();
219 if repo_urls.insert(repo_url.clone()) {
220 let repo_path = example_instance.repo_path.clone();
221
222 if !repo_path.join(".git").is_dir() {
223 println!(
224 "{:<width$} < {}",
225 "↓ Cloning",
226 repo_url,
227 width = max_name_width
228 );
229
230 let git_task = cx.spawn(async move |_cx| {
231 std::fs::create_dir_all(&repo_path)?;
232 run_git(&repo_path, &["init"]).await?;
233 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
234 });
235
236 clone_tasks.push(git_task);
237 } else {
238 println!(
239 "{:<width$} < {}",
240 "✔︎ Already cloned",
241 repo_url,
242 width = max_name_width
243 );
244
245 let actual_origin =
246 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
247 if actual_origin != repo_url {
248 return Err(anyhow!(
249 "remote origin {} does not match expected origin {}",
250 actual_origin,
251 repo_url,
252 ));
253 }
254 }
255 }
256 }
257
258 future::join_all(clone_tasks).await;
259
260 for example_instance in examples.iter_mut() {
261 example_instance.fetch().await?;
262 }
263
264 let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
265 let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
266
267 future::join_all((0..args.concurrency).map(|_| {
268 let app_state = app_state.clone();
269 let model = model.clone();
270 let zed_commit_sha = zed_commit_sha.clone();
271 let zed_branch_name = zed_branch_name.clone();
272 let run_id = run_id.clone();
273 let examples = examples.clone();
274 let results = results_by_example_name.clone();
275 cx.spawn(async move |cx| {
276 loop {
277 let Some(mut example) = examples.borrow_mut().pop_front() else {
278 break;
279 };
280 let result = async {
281 example.setup().await?;
282 let run_output = cx
283 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
284 .await?;
285 let judge_output = judge_example(
286 example.clone(),
287 model.clone(),
288 &zed_commit_sha,
289 &zed_branch_name,
290 &run_id,
291 &run_output,
292 enable_telemetry,
293 cx,
294 )
295 .await;
296 anyhow::Ok((run_output, judge_output))
297 }
298 .await;
299 results
300 .borrow_mut()
301 .entry(example.name.clone())
302 .or_insert(Vec::new())
303 .push((example.clone(), result));
304 }
305 })
306 }))
307 .await;
308
309 print_report(
310 &mut results_by_example_name.borrow_mut(),
311 &mut cumulative_tool_metrics,
312 &run_dir,
313 )?;
314
315 app_state.client.telemetry().flush_events().await;
316
317 cx.update(|cx| cx.quit())
318 })
319 .detach_and_log_err(cx);
320 });
321}
322
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Constructed by [`init`], which returns it wrapped in an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created with `Client::production`; also the source of the telemetry handle
    /// and the HTTP client used during initialization.
    pub client: Arc<Client>,
    /// User store entity created from `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem implementation (`init` supplies a `RealFs`).
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options are fed from the project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded in `init` and also passed to `agent::init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
334
/// Initializes the global application state the eval runner relies on and returns the
/// shared [`AgentAppState`].
///
/// In order, this sets up: the release channel and Tokio integration, the settings store
/// (with default settings), the HTTP client (proxy + User-Agent), the production client,
/// the filesystem, the language registry, the user store, extensions, the Node runtime,
/// the language/model/tool/agent subsystems, and finally the bundled runner settings.
///
/// The statement order matters: later calls read globals installed by earlier ones.
///
/// # Panics
///
/// Panics if the default settings cannot be applied, the HTTP client cannot be created,
/// or the bundled `runner_settings.json` is rejected by the settings store.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install the settings store as a global before anything reads settings from `cx`.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings when it parses as a URI; otherwise fall back to the
    // environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // NOTE(review): the guard keeps the Tokio runtime context entered while the client
        // is constructed — presumably the client needs a runtime at creation; confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // Replace the app's HTTP client with the one exposed by the production client.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the settings store changes, recompute the Node binary options from the
    // project settings and publish them to the Node runtime through the watch channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, look for `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the settings bundled with the runner as the user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
439
440pub fn find_model(
441 model_name: &str,
442 model_registry: &LanguageModelRegistry,
443 cx: &App,
444) -> anyhow::Result<Arc<dyn LanguageModel>> {
445 let model = model_registry
446 .available_models(cx)
447 .find(|model| model.id().0 == model_name);
448
449 let Some(model) = model else {
450 return Err(anyhow!(
451 "No language model named {} was available. Available models: {}",
452 model_name,
453 model_registry
454 .available_models(cx)
455 .map(|model| model.id().0.clone())
456 .collect::<Vec<_>>()
457 .join(", ")
458 ));
459 };
460
461 Ok(model)
462}
463
464pub fn commit_sha_for_path(repo_path: &Path) -> String {
465 futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
466}
467
468pub fn git_branch_for_path(repo_path: &Path) -> String {
469 match std::env::var("GITHUB_REF_NAME") {
470 Ok(branch) => branch,
471 Err(_) => {
472 futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
473 .unwrap_or_else(|_| "unknown".to_string())
474 }
475 }
476}
477
478async fn judge_example(
479 example: ExampleInstance,
480 model: Arc<dyn LanguageModel>,
481 zed_commit_sha: &str,
482 zed_branch_name: &str,
483 run_id: &str,
484 run_output: &RunOutput,
485 enable_telemetry: bool,
486 cx: &AsyncApp,
487) -> JudgeOutput {
488 let judge_output = example.judge(model.clone(), &run_output, cx).await;
489
490 if enable_telemetry {
491 telemetry::event!(
492 "Agent Example Evaluated",
493 zed_commit_sha = zed_commit_sha,
494 zed_branch_name = zed_branch_name,
495 run_id = run_id,
496 example_name = example.name.clone(),
497 example_repetition = example.repetition,
498 diff_evaluation = judge_output.diff.clone(),
499 thread_evaluation = judge_output.thread.clone(),
500 tool_metrics = run_output.tool_metrics,
501 response_count = run_output.response_count,
502 token_usage = run_output.token_usage,
503 model = model.telemetry_id(),
504 model_provider = model.provider_id().to_string(),
505 repository_url = example.repo_url(),
506 repository_revision = example.revision(),
507 diagnostic_summary_before = run_output.diagnostic_summary_before,
508 diagnostic_summary_after = run_output.diagnostic_summary_after,
509 diagnostics_before = run_output.diagnostics_before,
510 diagnostics_after = run_output.diagnostics_after,
511 );
512 }
513
514 judge_output
515}
516
/// Width, in characters, of the rule lines and centered titles printed by `print_h1`
/// and `print_h2`.
const HEADER_WIDTH: usize = 65;
518
519fn print_h1(header: &str) {
520 println!("\n\n{:=^HEADER_WIDTH$}", "");
521 println!("{:^HEADER_WIDTH$}", header);
522 println!("{:=^HEADER_WIDTH$}\n", "");
523}
524
525fn print_h2(header: &str) {
526 println!("\n{:-^HEADER_WIDTH$}", "");
527 println!("{:^HEADER_WIDTH$}", header);
528 println!("{:-^HEADER_WIDTH$}\n", "");
529}
530
531fn print_report(
532 results_by_example_name: &mut HashMap<
533 String,
534 Vec<(ExampleInstance, anyhow::Result<(RunOutput, JudgeOutput)>)>,
535 >,
536 cumulative_tool_metrics: &mut ToolMetrics,
537 run_dir: &Path,
538) -> anyhow::Result<()> {
539 print_h1("EVAL RESULTS");
540
541 let mut diff_scores = Vec::new();
542 let mut thread_scores = Vec::new();
543 let mut programmatic_scores = Vec::new();
544 let mut error_count = 0;
545
546 for (example_name, results) in results_by_example_name.iter_mut() {
547 print_h2(example_name);
548
549 results.sort_unstable_by_key(|(example, _)| example.repetition);
550 let mut example_cumulative_tool_metrics = ToolMetrics::default();
551
552 let mut table_rows = String::new();
553
554 for (example, result) in results.iter() {
555 match result {
556 Err(err) => {
557 display_error_row(&mut table_rows, example.repetition, err.to_string())?;
558 error_count += 1;
559 }
560 Ok((run_output, judge_output)) => {
561 cumulative_tool_metrics.merge(&run_output.tool_metrics);
562 example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
563
564 if !run_output.programmatic_assertions.total_count() > 0 {
565 for assertion in &run_output.programmatic_assertions.ran {
566 assertions::display_table_row(
567 &mut table_rows,
568 example.repetition,
569 assertion,
570 )?;
571 }
572
573 programmatic_scores
574 .push(run_output.programmatic_assertions.passed_percentage())
575 }
576
577 if !judge_output.diff.is_empty() {
578 diff_scores.push(judge_output.diff.passed_percentage());
579
580 for assertion in &judge_output.diff.ran {
581 assertions::display_table_row(
582 &mut table_rows,
583 example.repetition,
584 assertion,
585 )?;
586 }
587 }
588
589 if !judge_output.thread.is_empty() {
590 thread_scores.push(judge_output.thread.passed_percentage());
591
592 for assertion in &judge_output.thread.ran {
593 assertions::display_table_row(
594 &mut table_rows,
595 example.repetition,
596 assertion,
597 )?;
598 }
599 }
600 }
601 }
602 }
603
604 if !table_rows.is_empty() {
605 assertions::print_table_header();
606 print!("{}", table_rows);
607
608 assertions::print_table_divider();
609
610 for (example, result) in results.iter() {
611 if let Ok((run_output, judge_output)) = result {
612 assertions::print_table_round_summary(
613 &example.repetition.to_string(),
614 [
615 &run_output.programmatic_assertions,
616 &judge_output.diff,
617 &judge_output.thread,
618 ]
619 .into_iter(),
620 )
621 }
622 }
623
624 assertions::print_table_divider();
625
626 assertions::print_table_round_summary(
627 "avg",
628 results.iter().flat_map(|(_, result)| {
629 result.iter().flat_map(|(run_output, judge_output)| {
630 [
631 &run_output.programmatic_assertions,
632 &judge_output.diff,
633 &judge_output.thread,
634 ]
635 .into_iter()
636 })
637 }),
638 );
639
640 assertions::print_table_footer();
641 }
642
643 if !example_cumulative_tool_metrics.is_empty() {
644 println!("{}", &example_cumulative_tool_metrics);
645 }
646 }
647
648 if results_by_example_name.len() > 1 {
649 print_h1("AGGREGATE");
650
651 if error_count > 0 {
652 println!("\n{error_count} examples failed to run!");
653 }
654
655 let programmatic_score_count = programmatic_scores.len();
656 if programmatic_score_count > 0 {
657 let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
658 / (programmatic_score_count as f32))
659 .floor();
660 println!("Average programmatic score: {average_programmatic_score}%");
661 }
662
663 let diff_score_count = diff_scores.len();
664 if diff_score_count > 0 {
665 let average_diff_score =
666 (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
667 println!("Average diff score: {average_diff_score}%");
668 }
669
670 let thread_score_count = thread_scores.len();
671
672 if thread_score_count > 0 {
673 let average_thread_score =
674 (thread_scores.into_iter().sum::<f32>() / (thread_score_count as f32)).floor();
675 println!("Average thread score: {average_thread_score}%");
676 }
677
678 println!("");
679
680 print_h2("CUMULATIVE TOOL METRICS");
681 println!("{}", cumulative_tool_metrics);
682 }
683
684 let explorer_output_path = run_dir.join("overview.html");
685 let mut json_paths: Vec<PathBuf> = results_by_example_name
686 .values()
687 .flat_map(|results| {
688 results.iter().map(|(example, _)| {
689 let absolute_path = example.run_directory.join("last.messages.json");
690 pathdiff::diff_paths(&absolute_path, run_dir)
691 .unwrap_or_else(|| absolute_path.clone())
692 })
693 })
694 .collect::<Vec<_>>();
695 json_paths.sort();
696 if let Err(err) = explorer::generate_explorer_html(&json_paths, &explorer_output_path) {
697 eprintln!("Failed to generate explorer HTML: {}", err);
698 }
699
700 Ok(())
701}