1mod assertions;
2mod example;
3mod examples;
4mod ids;
5mod instance;
6mod tool_metrics;
7
8use assertions::display_error_row;
9use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
10pub(crate) use tool_metrics::*;
11
12use ::fs::RealFs;
13use anyhow::anyhow;
14use clap::Parser;
15use client::{Client, ProxySettings, UserStore};
16use collections::{HashMap, HashSet};
17use extension::ExtensionHostProxy;
18use futures::future;
19use gpui::http_client::{Uri, read_proxy_from_env};
20use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
21use gpui_tokio::Tokio;
22use language::LanguageRegistry;
23use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
24use node_runtime::{NodeBinaryOptions, NodeRuntime};
25use project::Project;
26use project::project_settings::ProjectSettings;
27use prompt_store::PromptBuilder;
28use release_channel::AppVersion;
29use reqwest_client::ReqwestClient;
30use settings::{Settings, SettingsStore};
31use std::cell::RefCell;
32use std::collections::VecDeque;
33use std::env;
34use std::path::{Path, PathBuf};
35use std::rc::Rc;
36use std::sync::Arc;
37use util::ResultExt as _;
38
39#[derive(Parser, Debug)]
40#[command(name = "eval", disable_version_flag = true)]
41struct Args {
42 /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
43 #[arg(value_name = "EXAMPLE_SUBSTRING")]
44 filter: Vec<String>,
45 /// Model to use (default: "claude-3-7-sonnet-latest")
46 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
47 model: String,
48 #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
49 languages: Vec<String>,
50 /// How many times to run each example.
51 #[arg(long, default_value = "1")]
52 repetitions: usize,
53 /// Maximum number of examples to run concurrently.
54 #[arg(long, default_value = "10")]
55 concurrency: usize,
56}
57
58fn main() {
59 env_logger::init();
60
61 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
62 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
63 let session_id = uuid::Uuid::new_v4().to_string();
64 let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
65 let run_id = match env::var("GITHUB_RUN_ID") {
66 Ok(run_id) => format!("github/{}", run_id),
67 Err(_) => format!("local/{}", run_timestamp),
68 };
69
70 let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
71 .parent()
72 .unwrap()
73 .parent()
74 .unwrap()
75 .canonicalize()
76 .unwrap();
77 let eval_crate_dir = root_dir.join("crates").join("eval");
78 let repos_dir = eval_crate_dir.join("repos");
79 let worktrees_dir = eval_crate_dir.join("worktrees");
80 let examples_dir = eval_crate_dir.join("src").join("examples");
81 let run_dir = eval_crate_dir
82 .join("runs")
83 .join(format!("{}", run_timestamp));
84 std::fs::create_dir_all(&run_dir).unwrap();
85 std::fs::create_dir_all(&repos_dir).unwrap();
86 std::fs::create_dir_all(&worktrees_dir).unwrap();
87 std::fs::create_dir_all(&examples_dir).unwrap();
88 std::fs::create_dir_all(&paths::config_dir()).unwrap();
89
90 let zed_commit_sha = commit_sha_for_path(&root_dir);
91 let zed_branch_name = git_branch_for_path(&root_dir);
92 let args = Args::parse();
93 let languages: HashSet<String> = args.languages.into_iter().collect();
94
95 let http_client = Arc::new(ReqwestClient::new());
96 let app = Application::headless().with_http_client(http_client.clone());
97 let all_threads = examples::all(&examples_dir);
98
99 app.run(move |cx| {
100 let app_state = init(cx);
101
102 let telemetry = app_state.client.telemetry();
103 telemetry.start(system_id, installation_id, session_id, cx);
104
105 let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
106 && telemetry.has_checksum_seed();
107 if enable_telemetry {
108 println!("Telemetry enabled");
109 telemetry::event!(
110 "Agent Eval Started",
111 zed_commit_sha = zed_commit_sha,
112 zed_branch_name = zed_branch_name,
113 run_id = run_id,
114 );
115 }
116
117 let mut cumulative_tool_metrics = ToolMetrics::default();
118
119 let model_registry = LanguageModelRegistry::read_global(cx);
120 let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
121 let model_provider_id = model.provider_id();
122 let model_provider = model_registry.provider(&model_provider_id).unwrap();
123
124 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
125 registry.set_default_model(
126 Some(ConfiguredModel {
127 provider: model_provider.clone(),
128 model: model.clone(),
129 }),
130 cx,
131 );
132 });
133
134 let authenticate_task = model_provider.authenticate(cx);
135
136 cx.spawn(async move |cx| {
137 authenticate_task.await.unwrap();
138
139 let mut examples = Vec::new();
140
141 const COLORS: [&str; 12] = [
142 "\x1b[31m", // Red
143 "\x1b[32m", // Green
144 "\x1b[33m", // Yellow
145 "\x1b[34m", // Blue
146 "\x1b[35m", // Magenta
147 "\x1b[36m", // Cyan
148 "\x1b[91m", // Bright Red
149 "\x1b[92m", // Bright Green
150 "\x1b[93m", // Bright Yellow
151 "\x1b[94m", // Bright Blue
152 "\x1b[95m", // Bright Magenta
153 "\x1b[96m", // Bright Cyan
154 ];
155
156 let mut skipped = Vec::new();
157
158 for thread in all_threads {
159 let meta = thread.meta();
160 if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
161 {
162 skipped.push(meta.name);
163 continue;
164 }
165
166 if meta.language_server.map_or(false, |language| {
167 !languages.contains(&language.file_extension)
168 }) {
169 skipped.push(meta.name);
170 continue;
171 }
172
173 // TODO: This creates a worktree per repetition. Ideally these examples should
174 // either be run sequentially on the same worktree, or reuse worktrees when there
175 // are more examples to run than the concurrency limit.
176 for repetition_number in 0..args.repetitions {
177 let example_instance = ExampleInstance::new(
178 thread.clone(),
179 &repos_dir,
180 &run_dir,
181 &worktrees_dir,
182 repetition_number,
183 );
184
185 examples.push(example_instance);
186 }
187 }
188
189 if !skipped.is_empty() {
190 println!("Skipped threads: {}", skipped.join(", "));
191 }
192
193 if examples.is_empty() {
194 eprintln!("Filter matched no examples");
195 return cx.update(|cx| cx.quit());
196 }
197
198 let mut repo_urls = HashSet::default();
199 let mut clone_tasks = Vec::new();
200
201 let max_name_width = examples
202 .iter()
203 .map(|e| e.worktree_name().len())
204 .max()
205 .unwrap_or(0);
206
207 for (i, example_instance) in examples.iter_mut().enumerate() {
208 let color = COLORS[i % COLORS.len()].to_string();
209 example_instance.set_log_prefix_style(&color, max_name_width);
210
211 println!(
212 "{}Logging to: {}",
213 example_instance.log_prefix,
214 example_instance.run_directory.display()
215 );
216
217 let repo_url = example_instance.repo_url();
218 if repo_urls.insert(repo_url.clone()) {
219 let repo_path = example_instance.repo_path.clone();
220
221 if !repo_path.join(".git").is_dir() {
222 println!(
223 "{:<width$} < {}",
224 "↓ Cloning",
225 repo_url,
226 width = max_name_width
227 );
228
229 let git_task = cx.spawn(async move |_cx| {
230 std::fs::create_dir_all(&repo_path)?;
231 run_git(&repo_path, &["init"]).await?;
232 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
233 });
234
235 clone_tasks.push(git_task);
236 } else {
237 println!(
238 "{:<width$} < {}",
239 "✔︎ Already cloned",
240 repo_url,
241 width = max_name_width
242 );
243
244 let actual_origin =
245 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
246 if actual_origin != repo_url {
247 return Err(anyhow!(
248 "remote origin {} does not match expected origin {}",
249 actual_origin,
250 repo_url,
251 ));
252 }
253 }
254 }
255 }
256
257 future::join_all(clone_tasks).await;
258
259 for example_instance in examples.iter_mut() {
260 example_instance.fetch().await?;
261 }
262
263 let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
264 let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
265
266 future::join_all((0..args.concurrency).map(|_| {
267 let app_state = app_state.clone();
268 let model = model.clone();
269 let zed_commit_sha = zed_commit_sha.clone();
270 let zed_branch_name = zed_branch_name.clone();
271 let run_id = run_id.clone();
272 let examples = examples.clone();
273 let results = results_by_example_name.clone();
274 cx.spawn(async move |cx| {
275 loop {
276 let Some(mut example) = examples.borrow_mut().pop_front() else {
277 break;
278 };
279 let result = async {
280 example.setup().await?;
281 let run_output = cx
282 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
283 .await?;
284 let judge_output = judge_example(
285 example.clone(),
286 model.clone(),
287 &zed_commit_sha,
288 &zed_branch_name,
289 &run_id,
290 &run_output,
291 enable_telemetry,
292 cx,
293 )
294 .await;
295 anyhow::Ok((run_output, judge_output))
296 }
297 .await;
298 results
299 .borrow_mut()
300 .entry(example.name.clone())
301 .or_insert(Vec::new())
302 .push((example.clone(), result));
303 }
304 })
305 }))
306 .await;
307
308 print_h1("EVAL RESULTS");
309
310 let mut diff_scores = Vec::new();
311 let mut thread_scores = Vec::new();
312 let mut programmatic_scores = Vec::new();
313 let mut error_count = 0;
314
315 for (example_name, results) in results_by_example_name.borrow_mut().iter_mut() {
316 print_h2(&example_name);
317
318 results.sort_unstable_by_key(|(example, _)| example.repetition);
319 let mut example_cumulative_tool_metrics = ToolMetrics::default();
320
321 let mut table_rows = String::new();
322
323 for (example, result) in results.iter() {
324 match result {
325 Err(err) => {
326 display_error_row(
327 &mut table_rows,
328 example.repetition,
329 err.to_string(),
330 )?;
331 error_count += 1;
332 }
333 Ok((run_output, judge_output)) => {
334 cumulative_tool_metrics.merge(&run_output.tool_metrics);
335 example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
336
337 if !run_output.programmatic_assertions.total_count() > 0 {
338 for assertion in &run_output.programmatic_assertions.ran {
339 assertions::display_table_row(
340 &mut table_rows,
341 example.repetition,
342 assertion,
343 )?;
344 }
345
346 programmatic_scores
347 .push(run_output.programmatic_assertions.passed_percentage())
348 }
349
350 if !judge_output.diff.is_empty() {
351 diff_scores.push(judge_output.diff.passed_percentage());
352
353 for assertion in &judge_output.diff.ran {
354 assertions::display_table_row(
355 &mut table_rows,
356 example.repetition,
357 assertion,
358 )?;
359 }
360 }
361
362 if !judge_output.thread.is_empty() {
363 thread_scores.push(judge_output.thread.passed_percentage());
364
365 for assertion in &judge_output.thread.ran {
366 assertions::display_table_row(
367 &mut table_rows,
368 example.repetition,
369 assertion,
370 )?;
371 }
372 }
373 }
374 }
375 }
376
377 if !table_rows.is_empty() {
378 assertions::print_table_header();
379 print!("{}", table_rows);
380
381 assertions::print_table_divider();
382
383 for (example, result) in results.iter() {
384 if let Ok((run_output, judge_output)) = result {
385 assertions::print_table_round_summary(
386 &example.repetition.to_string(),
387 [
388 &run_output.programmatic_assertions,
389 &judge_output.diff,
390 &judge_output.thread,
391 ]
392 .into_iter(),
393 )
394 }
395 }
396
397 assertions::print_table_divider();
398
399 assertions::print_table_round_summary(
400 "avg",
401 results.iter().flat_map(|(_, result)| {
402 result.iter().flat_map(|(run_output, judge_output)| {
403 [
404 &run_output.programmatic_assertions,
405 &judge_output.diff,
406 &judge_output.thread,
407 ]
408 .into_iter()
409 })
410 }),
411 );
412
413 assertions::print_table_footer();
414 }
415
416 if !example_cumulative_tool_metrics.is_empty() {
417 println!("{}", &example_cumulative_tool_metrics);
418 }
419 }
420
421 if results_by_example_name.borrow().len() > 1 {
422 print_h1("AGGREGATE");
423
424 if error_count > 0 {
425 println!("\n{error_count} examples failed to run!");
426 }
427
428 let programmatic_score_count = programmatic_scores.len();
429 if programmatic_score_count > 0 {
430 let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
431 / (programmatic_score_count as f32))
432 .floor();
433 println!("Average programmatic score: {average_programmatic_score}%");
434 }
435
436 let diff_score_count = diff_scores.len();
437 if diff_score_count > 0 {
438 let average_diff_score =
439 (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
440 println!("Average diff score: {average_diff_score}%");
441 }
442
443 let thread_score_count = thread_scores.len();
444
445 if thread_score_count > 0 {
446 let average_thread_score = (thread_scores.into_iter().sum::<f32>()
447 / (thread_score_count as f32))
448 .floor();
449 println!("Average thread score: {average_thread_score}%");
450 }
451
452 println!("");
453
454 print_h2("CUMULATIVE TOOL METRICS");
455 println!("{}", cumulative_tool_metrics);
456 }
457
458 app_state.client.telemetry().flush_events().await;
459
460 cx.update(|cx| cx.quit())
461 })
462 .detach_and_log_err(cx);
463 });
464}
465
466/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
467pub struct AgentAppState {
468 pub languages: Arc<LanguageRegistry>,
469 pub client: Arc<Client>,
470 pub user_store: Entity<UserStore>,
471 pub fs: Arc<dyn fs::Fs>,
472 pub node_runtime: NodeRuntime,
473
474 // Additional fields not present in `workspace::AppState`.
475 pub prompt_builder: Arc<PromptBuilder>,
476}
477
/// Initializes the globals needed to run agent threads headlessly and bundles the
/// resulting handles into an [`AgentAppState`].
///
/// Sequence: release channel and Tokio integration, the settings store (seeded with
/// the bundled defaults), an HTTP client honouring the proxy setting, the production
/// `Client`, filesystem, language registry, user store, node runtime, the language /
/// model / tool / agent subsystems, and finally the eval's `runner_settings.json`
/// applied as user settings. NOTE(review): the order looks load-bearing (e.g. the HTTP
/// client is installed before `Client::production` runs) — keep it when editing.
///
/// # Panics
///
/// Panics if the default settings or `runner_settings.json` cannot be applied, or if the
/// HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Settings store seeded with Zed's bundled default settings.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Use the configured proxy when it parses as a URI; otherwise fall back to the
    // proxy read from the environment (if any).
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // The reqwest client is built while the Tokio runtime handle is entered.
        // NOTE(review): presumably reqwest needs an active runtime context here — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // From here on the app-wide HTTP client is the one owned by `client`.
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Recompute the Node binary options every time the global `SettingsStore` changes
    // and publish them to the `NodeRuntime` through the watch channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            // A configured node path is tilde-expanded; npm comes from `npm_path`, or
            // defaults to an `npm` file next to the node binary.
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    // NOTE(review): presumably tells the prompt builder that output is not an interactive
    // terminal — confirm against `PromptBuilder::load`.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Eval-specific settings, applied as user settings on top of the defaults.
    // NOTE(review): this presumably also fires the `SettingsStore` observer above,
    // which publishes the first Node options — confirm.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
582
583pub fn find_model(
584 model_name: &str,
585 model_registry: &LanguageModelRegistry,
586 cx: &App,
587) -> anyhow::Result<Arc<dyn LanguageModel>> {
588 let model = model_registry
589 .available_models(cx)
590 .find(|model| model.id().0 == model_name);
591
592 let Some(model) = model else {
593 return Err(anyhow!(
594 "No language model named {} was available. Available models: {}",
595 model_name,
596 model_registry
597 .available_models(cx)
598 .map(|model| model.id().0.clone())
599 .collect::<Vec<_>>()
600 .join(", ")
601 ));
602 };
603
604 Ok(model)
605}
606
607pub fn commit_sha_for_path(repo_path: &Path) -> String {
608 futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
609}
610
611pub fn git_branch_for_path(repo_path: &Path) -> String {
612 match std::env::var("GITHUB_REF_NAME") {
613 Ok(branch) => branch,
614 Err(_) => {
615 futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
616 .unwrap_or_else(|_| "unknown".to_string())
617 }
618 }
619}
620
621async fn judge_example(
622 example: ExampleInstance,
623 model: Arc<dyn LanguageModel>,
624 zed_commit_sha: &str,
625 zed_branch_name: &str,
626 run_id: &str,
627 run_output: &RunOutput,
628 enable_telemetry: bool,
629 cx: &AsyncApp,
630) -> JudgeOutput {
631 let judge_output = example.judge(model.clone(), &run_output, cx).await;
632
633 if enable_telemetry {
634 telemetry::event!(
635 "Agent Example Evaluated",
636 zed_commit_sha = zed_commit_sha,
637 zed_branch_name = zed_branch_name,
638 run_id = run_id,
639 example_name = example.name.clone(),
640 example_repetition = example.repetition,
641 diff_evaluation = judge_output.diff.clone(),
642 thread_evaluation = judge_output.thread.clone(),
643 tool_metrics = run_output.tool_metrics,
644 response_count = run_output.response_count,
645 token_usage = run_output.token_usage,
646 model = model.telemetry_id(),
647 model_provider = model.provider_id().to_string(),
648 repository_url = example.repo_url(),
649 repository_revision = example.revision(),
650 diagnostic_summary_before = run_output.diagnostic_summary_before,
651 diagnostic_summary_after = run_output.diagnostic_summary_after,
652 diagnostics_before = run_output.diagnostics_before,
653 diagnostics_after = run_output.diagnostics_after,
654 );
655 }
656
657 judge_output
658}
659
/// Width, in columns, of the section headers in the eval report.
const HEADER_WIDTH: usize = 65;

/// Prints `header` centered between two rules made of `rule_char`.
///
/// The top rule is preceded by `leading_newlines` newlines, and the bottom rule is
/// followed by one blank line.
fn print_banner(header: &str, rule_char: char, leading_newlines: usize) {
    let rule = rule_char.to_string().repeat(HEADER_WIDTH);
    let lead = "\n".repeat(leading_newlines);
    println!("{lead}{rule}");
    println!("{header:^HEADER_WIDTH$}");
    println!("{rule}\n");
}

/// Prints a top-level report header framed by `=` rules.
fn print_h1(header: &str) {
    print_banner(header, '=', 2);
}

/// Prints a second-level report header framed by `-` rules.
fn print_h2(header: &str) {
    print_banner(header, '-', 1);
}