1mod example;
2mod ids;
3mod tool_metrics;
4
5pub(crate) use example::*;
6pub(crate) use tool_metrics::*;
7
8use ::fs::RealFs;
9use anyhow::{Result, anyhow};
10use clap::Parser;
11use client::{Client, ProxySettings, UserStore};
12use collections::HashSet;
13use extension::ExtensionHostProxy;
14use futures::{StreamExt, future};
15use gpui::http_client::{Uri, read_proxy_from_env};
16use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
17use gpui_tokio::Tokio;
18use language::LanguageRegistry;
19use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
20use node_runtime::{NodeBinaryOptions, NodeRuntime};
21use project::Project;
22use project::project_settings::ProjectSettings;
23use prompt_store::PromptBuilder;
24use release_channel::AppVersion;
25use reqwest_client::ReqwestClient;
26use settings::{Settings, SettingsStore};
27use std::path::{Path, PathBuf};
28use std::sync::Arc;
29use std::usize;
30use util::ResultExt as _;
31
32pub const RUNS_DIR: &str = "./crates/eval/runs";
33
34#[derive(Parser, Debug)]
35#[command(name = "eval", disable_version_flag = true)]
36struct Args {
37 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
38 #[arg(value_name = "EXAMPLE_SUBSTRING")]
39 examples: Vec<String>,
40 /// Model to use (default: "claude-3-7-sonnet-latest")
41 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
42 model: String,
43 #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
44 languages: Vec<String>,
45 /// How many times to run each example. Note that this is currently not very efficient as N
46 /// worktrees will be created for the examples.
47 #[arg(long, default_value = "1")]
48 repetitions: u32,
49 /// How many times to run the judge on each example run.
50 #[arg(long, default_value = "3")]
51 judge_repetitions: u32,
52 /// Maximum number of examples to run concurrently.
53 #[arg(long, default_value = "10")]
54 concurrency: usize,
55}
56
57fn main() {
58 env_logger::init();
59
60 let args = Args::parse();
61 let all_available_examples = list_all_examples().unwrap();
62
63 let example_paths = all_available_examples
64 .iter()
65 .filter_map(|example_path| {
66 let name = example_path.file_name()?.to_string_lossy();
67 if args.examples.is_empty()
68 || args
69 .examples
70 .iter()
71 .any(|name_substring| name.contains(name_substring))
72 {
73 Some(example_path.clone())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<_>>();
79
80 let http_client = Arc::new(ReqwestClient::new());
81 let app = Application::headless().with_http_client(http_client.clone());
82
83 app.run(move |cx| {
84 let app_state = init(cx);
85
86 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
87 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
88 let session_id = uuid::Uuid::new_v4().to_string();
89
90 app_state
91 .client
92 .telemetry()
93 .start(system_id, installation_id, session_id, cx);
94
95 let mut cumulative_tool_metrics = ToolMetrics::default();
96
97 let model_registry = LanguageModelRegistry::read_global(cx);
98 let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
99 let model_provider_id = model.provider_id();
100 let model_provider = model_registry.provider(&model_provider_id).unwrap();
101
102 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
103 registry.set_default_model(
104 Some(ConfiguredModel {
105 provider: model_provider.clone(),
106 model: model.clone(),
107 }),
108 cx,
109 );
110 });
111
112 let authenticate_task = model_provider.authenticate(cx);
113
114 cx.spawn(async move |cx| {
115 authenticate_task.await.unwrap();
116
117 std::fs::create_dir_all(REPOS_DIR)?;
118 std::fs::create_dir_all(WORKTREES_DIR)?;
119
120 let run_dir = Path::new(RUNS_DIR).join(format!(
121 "{}",
122 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
123 ));
124 std::fs::create_dir_all(&run_dir)?;
125
126 let mut examples = Vec::new();
127
128 const COLORS: [&str; 12] = [
129 "\x1b[31m", // Red
130 "\x1b[32m", // Green
131 "\x1b[33m", // Yellow
132 "\x1b[34m", // Blue
133 "\x1b[35m", // Magenta
134 "\x1b[36m", // Cyan
135 "\x1b[91m", // Bright Red
136 "\x1b[92m", // Bright Green
137 "\x1b[93m", // Bright Yellow
138 "\x1b[94m", // Bright Blue
139 "\x1b[95m", // Bright Magenta
140 "\x1b[96m", // Bright Cyan
141 ];
142
143 let mut max_name_width = 0;
144 let mut skipped = Vec::new();
145
146 for example_path in &example_paths {
147 let example = Example::load_from_directory(example_path, &run_dir)?;
148
149 if !example
150 .base
151 .language_extension
152 .as_ref()
153 .map_or(false, |lang| args.languages.contains(lang))
154 {
155 skipped.push(example.name);
156 continue;
157 }
158
159 // TODO: This creates a worktree per repetition. Ideally these examples should
160 // either be run sequentially on the same worktree, or reuse worktrees when there
161 // are more examples to run than the concurrency limit.
162 for repetition_number in 0..args.repetitions {
163 let mut example = example.clone();
164 example.set_repetition_number(repetition_number);
165
166 let name_len = example.name.len();
167 if name_len > max_name_width {
168 max_name_width = example.name.len();
169 }
170
171 examples.push(example);
172 }
173 }
174
175 println!("Skipped examples: {}\n", skipped.join(", "));
176
177 if examples.is_empty() {
178 eprintln!("Filter matched no examples");
179 return cx.update(|cx| cx.quit());
180 }
181
182 let mut repo_urls = HashSet::default();
183 let mut clone_tasks = Vec::new();
184
185 for (i, example) in examples.iter_mut().enumerate() {
186 let color = COLORS[i % COLORS.len()].to_string();
187 example.set_log_prefix_style(&color, max_name_width);
188
189 println!(
190 "{}Logging to: {}",
191 example.log_prefix,
192 example.example_output_directory().display()
193 );
194
195 let repo_url = example.base.url.clone();
196 if repo_urls.insert(repo_url.clone()) {
197 let repo_path = repo_path_for_url(&repo_url);
198
199 if !repo_path.join(".git").is_dir() {
200 println!(
201 "{:<width$} < {}",
202 "↓ Cloning",
203 repo_url,
204 width = max_name_width
205 );
206
207 let git_task = cx.spawn(async move |_cx| {
208 std::fs::create_dir_all(&repo_path)?;
209 run_git(&repo_path, &["init"]).await?;
210 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
211 });
212
213 clone_tasks.push(git_task);
214 } else {
215 println!(
216 "{:<width$} < {}",
217 "✔︎ Already cloned",
218 repo_url,
219 width = max_name_width
220 );
221
222 let actual_origin =
223 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
224 if actual_origin != repo_url {
225 return Err(anyhow!(
226 "remote origin {} does not match expected origin {}",
227 actual_origin,
228 repo_url,
229 ));
230 }
231 }
232 }
233 }
234
235 future::join_all(clone_tasks).await;
236
237 for example in examples.iter_mut() {
238 example.setup().await?;
239 }
240
241 let judge_repetitions = args.judge_repetitions;
242 let concurrency = args.concurrency;
243
244 let tasks = examples.iter().map(|example| {
245 let app_state = app_state.clone();
246 let model = model.clone();
247 let example = example.clone();
248 cx.spawn(async move |cx| {
249 let result = async {
250 let run_output = cx
251 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
252 .await?;
253 let judge_tasks = (0..judge_repetitions).map(|round| {
254 run_judge_repetition(
255 example.clone(),
256 model.clone(),
257 &run_output,
258 round,
259 cx,
260 )
261 });
262 let judge_outputs = future::join_all(judge_tasks).await;
263 anyhow::Ok((run_output, judge_outputs))
264 }
265 .await;
266 (example, result)
267 })
268 });
269
270 let results = futures::stream::iter(tasks)
271 .buffer_unordered(concurrency)
272 .collect::<Vec<_>>()
273 .await;
274
275 println!("\n\n");
276 print_header("EVAL RESULTS");
277
278 let mut diff_scores = Vec::new();
279 let mut thread_scores = Vec::new();
280 let mut error_count = 0;
281
282 for (example, result) in results {
283 print_header(&example.name);
284
285 match result {
286 Err(err) => {
287 println!("💥 {}{:?}", example.log_prefix, err);
288 error_count += 1;
289 }
290 Ok((run_output, judge_results)) => {
291 cumulative_tool_metrics.merge(&run_output.tool_metrics);
292
293 println!("┌───────┬──────┬────────┐");
294 println!("│ Judge │ Diff │ Thread │");
295 println!("├───────┼──────┼────────┤");
296
297 for (i, judge_result) in judge_results.iter().enumerate() {
298 match judge_result {
299 Ok(judge_output) => {
300 let diff_score = judge_output.diff.score;
301 diff_scores.push(diff_score);
302
303 let thread_display = if let Some(thread) = &judge_output.thread
304 {
305 let thread_score = thread.score;
306 thread_scores.push(thread_score);
307 format!("{}", thread_score)
308 } else {
309 "N/A".to_string()
310 };
311
312 println!(
313 "|{:^7}│{:^6}│{:^8}│",
314 i + 1,
315 diff_score,
316 thread_display
317 );
318 }
319 Err(err) => {
320 println!("|{:^7}│{:^6}│{:^8}│{:?}", i + 1, "N/A", "N/A", err);
321 }
322 }
323 }
324
325 println!("└───────┴──────┴────────┘");
326
327 println!("{}", run_output.tool_metrics);
328 }
329 }
330 println!(
331 "{} > {}",
332 " ".repeat(max_name_width),
333 example.example_output_directory().display()
334 );
335 }
336
337 let diff_score_count = diff_scores.len();
338 let average_diff_score = diff_scores
339 .into_iter()
340 .map(|score| score as f32)
341 .sum::<f32>()
342 / (diff_score_count as f32);
343
344 if error_count > 0 {
345 println!("\n{error_count} examples failed to run!");
346 }
347
348 if diff_score_count > 0 {
349 println!("\nAverage code diff score: {average_diff_score}");
350 }
351
352 let thread_score_count = thread_scores.len();
353
354 // We might have gotten no thread scores if we weren't asked to judge the thread.
355 if thread_score_count > 0 {
356 let average_thread_score = thread_scores
357 .into_iter()
358 .map(|score| score as f32)
359 .sum::<f32>()
360 / (thread_score_count as f32);
361
362 if diff_score_count > 0 {
363 println!("\nAverage thread score: {average_thread_score}");
364 }
365 }
366
367 print_header("CUMULATIVE TOOL METRICS");
368 println!("{}", cumulative_tool_metrics);
369
370 std::thread::sleep(std::time::Duration::from_secs(2));
371
372 app_state.client.telemetry().flush_events();
373
374 cx.update(|cx| cx.quit())
375 })
376 .detach_and_log_err(cx);
377 });
378}
379
380fn list_all_examples() -> Result<Vec<PathBuf>> {
381 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
382 let entries = std::fs::read_dir(path).unwrap();
383 let mut result_paths = Vec::new();
384 for entry in entries {
385 let entry = entry?;
386 let path = entry.path();
387 if path.is_dir() {
388 result_paths.push(path);
389 }
390 }
391 Ok(result_paths)
392}
393
394/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
395pub struct AgentAppState {
396 pub languages: Arc<LanguageRegistry>,
397 pub client: Arc<Client>,
398 pub user_store: Entity<UserStore>,
399 pub fs: Arc<dyn fs::Fs>,
400 pub node_runtime: NodeRuntime,
401
402 // Additional fields not present in `workspace::AppState`.
403 pub prompt_builder: Arc<PromptBuilder>,
404}
405
406pub fn init(cx: &mut App) -> Arc<AgentAppState> {
407 release_channel::init(SemanticVersion::default(), cx);
408 gpui_tokio::init(cx);
409
410 let mut settings_store = SettingsStore::new(cx);
411 settings_store
412 .set_default_settings(settings::default_settings().as_ref(), cx)
413 .unwrap();
414 cx.set_global(settings_store);
415 client::init_settings(cx);
416
417 // Set User-Agent so we can download language servers from GitHub
418 let user_agent = format!(
419 "Zed/{} ({}; {})",
420 AppVersion::global(cx),
421 std::env::consts::OS,
422 std::env::consts::ARCH
423 );
424 let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
425 let proxy_url = proxy_str
426 .as_ref()
427 .and_then(|input| input.parse::<Uri>().ok())
428 .or_else(read_proxy_from_env);
429 let http = {
430 let _guard = Tokio::handle(cx).enter();
431
432 ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
433 .expect("could not start HTTP client")
434 };
435 cx.set_http_client(Arc::new(http));
436
437 Project::init_settings(cx);
438
439 let client = Client::production(cx);
440 cx.set_http_client(client.http_client().clone());
441
442 let git_binary_path = None;
443 let fs = Arc::new(RealFs::new(
444 git_binary_path,
445 cx.background_executor().clone(),
446 ));
447
448 let mut languages = LanguageRegistry::new(cx.background_executor().clone());
449 languages.set_language_server_download_dir(paths::languages_dir().clone());
450 let languages = Arc::new(languages);
451
452 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
453
454 extension::init(cx);
455
456 let (tx, rx) = async_watch::channel(None);
457 cx.observe_global::<SettingsStore>(move |cx| {
458 let settings = &ProjectSettings::get_global(cx).node;
459 let options = NodeBinaryOptions {
460 allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
461 allow_binary_download: true,
462 use_paths: settings.path.as_ref().map(|node_path| {
463 let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
464 let npm_path = settings
465 .npm_path
466 .as_ref()
467 .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
468 (
469 node_path.clone(),
470 npm_path.unwrap_or_else(|| {
471 let base_path = PathBuf::new();
472 node_path.parent().unwrap_or(&base_path).join("npm")
473 }),
474 )
475 }),
476 };
477 tx.send(Some(options)).log_err();
478 })
479 .detach();
480 let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
481
482 let extension_host_proxy = ExtensionHostProxy::global(cx);
483
484 language::init(cx);
485 language_extension::init(extension_host_proxy.clone(), languages.clone());
486 language_model::init(client.clone(), cx);
487 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
488 languages::init(languages.clone(), node_runtime.clone(), cx);
489 assistant_tools::init(client.http_client().clone(), cx);
490 context_server::init(cx);
491 prompt_store::init(cx);
492 let stdout_is_a_pty = false;
493 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
494 agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
495
496 SettingsStore::update_global(cx, |store, cx| {
497 store.set_user_settings(include_str!("../runner_settings.json"), cx)
498 })
499 .unwrap();
500
501 Arc::new(AgentAppState {
502 languages,
503 client,
504 user_store,
505 fs,
506 node_runtime,
507 prompt_builder,
508 })
509}
510
511pub fn find_model(
512 model_name: &str,
513 model_registry: &LanguageModelRegistry,
514 cx: &App,
515) -> anyhow::Result<Arc<dyn LanguageModel>> {
516 let model = model_registry
517 .available_models(cx)
518 .find(|model| model.id().0 == model_name);
519
520 let Some(model) = model else {
521 return Err(anyhow!(
522 "No language model named {} was available. Available models: {}",
523 model_name,
524 model_registry
525 .available_models(cx)
526 .map(|model| model.id().0.clone())
527 .collect::<Vec<_>>()
528 .join(", ")
529 ));
530 };
531
532 Ok(model)
533}
534
535pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
536 (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
537}
538
539pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
540 futures::executor::block_on(async {
541 get_current_commit_id(repo_path).await.unwrap_or_default()
542 })
543}
544
545async fn run_judge_repetition(
546 example: Example,
547 model: Arc<dyn LanguageModel>,
548 run_output: &RunOutput,
549 round: u32,
550 cx: &AsyncApp,
551) -> Result<JudgeOutput> {
552 let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
553
554 if let Ok(judge_output) = &judge_result {
555 let cohort_id = example
556 .run_directory_path
557 .file_name()
558 .map(|name| name.to_string_lossy().to_string())
559 .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
560
561 let path = std::path::Path::new(".");
562 let commit_id = get_current_commit_id(path).await.unwrap_or_default();
563
564 if let Some(thread) = &judge_output.thread {
565 telemetry::event!(
566 "Agent Eval Completed",
567 cohort_id = cohort_id,
568 example_name = example.name.clone(),
569 round = round,
570 diff_score = judge_output.diff.score,
571 diff_analysis = judge_output.diff.analysis,
572 thread_score = thread.score,
573 thread_analysis = thread.analysis,
574 tool_metrics = run_output.tool_metrics,
575 response_count = run_output.response_count,
576 token_usage = run_output.token_usage,
577 model = model.telemetry_id(),
578 model_provider = model.provider_id().to_string(),
579 repository_url = example.base.url.clone(),
580 repository_revision = example.base.revision.clone(),
581 diagnostics_before = run_output.diagnostics_before,
582 diagnostics_after = run_output.diagnostics_after,
583 commit_id = commit_id
584 );
585 } else {
586 telemetry::event!(
587 "Agent Eval Completed",
588 cohort_id = cohort_id,
589 example_name = example.name.clone(),
590 round = round,
591 diff_score = judge_output.diff.score,
592 diff_analysis = judge_output.diff.analysis,
593 tool_metrics = run_output.tool_metrics,
594 response_count = run_output.response_count,
595 token_usage = run_output.token_usage,
596 model = model.telemetry_id(),
597 model_provider = model.provider_id().to_string(),
598 repository_url = example.base.url.clone(),
599 repository_revision = example.base.revision.clone(),
600 diagnostics_before = run_output.diagnostics_before,
601 diagnostics_after = run_output.diagnostics_after,
602 commit_id = commit_id
603 );
604 }
605 }
606
607 judge_result
608}
609
/// Prints `header` centered in a 40-column field, framed by a rule of `=` characters above and
/// below and surrounded by blank lines.
fn print_header(header: &str) {
    let rule = "========================================";
    // One call emits: blank line, rule, centered header, rule, blank line.
    println!("\n{}\n{:^40}\n{}\n", rule, header, rule);
}