1mod example;
2mod ids;
3mod tool_metrics;
4
5pub(crate) use example::*;
6pub(crate) use tool_metrics::*;
7
8use ::fs::RealFs;
9use anyhow::{Result, anyhow};
10use clap::Parser;
11use client::{Client, ProxySettings, UserStore};
12use collections::HashSet;
13use extension::ExtensionHostProxy;
14use futures::{StreamExt, future};
15use gpui::http_client::{Uri, read_proxy_from_env};
16use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
17use gpui_tokio::Tokio;
18use language::LanguageRegistry;
19use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
20use node_runtime::{NodeBinaryOptions, NodeRuntime};
21use project::Project;
22use project::project_settings::ProjectSettings;
23use prompt_store::PromptBuilder;
24use release_channel::AppVersion;
25use reqwest_client::ReqwestClient;
26use settings::{Settings, SettingsStore};
27use std::path::{Path, PathBuf};
28use std::sync::Arc;
29use std::usize;
30use util::ResultExt as _;
31
/// Directory, relative to the current working directory, under which eval run output is stored.
///
/// `main` creates one timestamped subdirectory (`%Y-%m-%d_%H-%M-%S`) inside it per invocation.
pub const RUNS_DIR: &str = "./crates/eval/runs";
33
34#[derive(Parser, Debug)]
35#[command(name = "eval", disable_version_flag = true)]
36struct Args {
37 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
38 #[arg(value_name = "EXAMPLE_SUBSTRING")]
39 examples: Vec<String>,
40 /// Model to use (default: "claude-3-7-sonnet-latest")
41 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
42 model: String,
43 #[arg(long, value_delimiter = ',')]
44 languages: Option<Vec<String>>,
45 /// How many times to run each example. Note that this is currently not very efficient as N
46 /// worktrees will be created for the examples.
47 #[arg(long, default_value = "1")]
48 repetitions: u32,
49 /// How many times to run the judge on each example run.
50 #[arg(long, default_value = "3")]
51 judge_repetitions: u32,
52 /// Maximum number of examples to run concurrently.
53 #[arg(long, default_value = "10")]
54 concurrency: usize,
55}
56
57fn main() {
58 env_logger::init();
59
60 let args = Args::parse();
61 let all_available_examples = list_all_examples().unwrap();
62 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
63
64 let example_paths = all_available_examples
65 .iter()
66 .filter_map(|example_path| {
67 let name = example_path.file_name()?.to_string_lossy();
68 if args.examples.is_empty()
69 || args
70 .examples
71 .iter()
72 .any(|name_substring| name.contains(name_substring))
73 {
74 Some(example_path.clone())
75 } else {
76 None
77 }
78 })
79 .collect::<Vec<_>>();
80
81 let http_client = Arc::new(ReqwestClient::new());
82 let app = Application::headless().with_http_client(http_client.clone());
83
84 app.run(move |cx| {
85 let app_state = init(cx);
86
87 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
88 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
89 let session_id = uuid::Uuid::new_v4().to_string();
90
91 app_state
92 .client
93 .telemetry()
94 .start(system_id, installation_id, session_id, cx);
95
96 let mut cumulative_tool_metrics = ToolMetrics::default();
97
98 let model_registry = LanguageModelRegistry::read_global(cx);
99 let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
100 let model_provider_id = model.provider_id();
101 let model_provider = model_registry.provider(&model_provider_id).unwrap();
102
103 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
104 registry.set_default_model(
105 Some(ConfiguredModel {
106 provider: model_provider.clone(),
107 model: model.clone(),
108 }),
109 cx,
110 );
111 });
112
113 let authenticate_task = model_provider.authenticate(cx);
114
115 cx.spawn(async move |cx| {
116 authenticate_task.await.unwrap();
117
118 std::fs::create_dir_all(REPOS_DIR)?;
119 std::fs::create_dir_all(WORKTREES_DIR)?;
120
121 let run_dir = Path::new(RUNS_DIR).join(format!(
122 "{}",
123 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
124 ));
125 std::fs::create_dir_all(&run_dir)?;
126
127 let mut examples = Vec::new();
128
129 const COLORS: [&str; 12] = [
130 "\x1b[31m", // Red
131 "\x1b[32m", // Green
132 "\x1b[33m", // Yellow
133 "\x1b[34m", // Blue
134 "\x1b[35m", // Magenta
135 "\x1b[36m", // Cyan
136 "\x1b[91m", // Bright Red
137 "\x1b[92m", // Bright Green
138 "\x1b[93m", // Bright Yellow
139 "\x1b[94m", // Bright Blue
140 "\x1b[95m", // Bright Magenta
141 "\x1b[96m", // Bright Cyan
142 ];
143
144 let mut max_name_width = 0;
145 let mut skipped = Vec::new();
146
147 for example_path in &example_paths {
148 let example = Example::load_from_directory(example_path, &run_dir)?;
149
150 if !example
151 .base
152 .language_extension
153 .as_ref()
154 .map_or(false, |lang| languages.contains(lang))
155 {
156 skipped.push(example.name);
157 continue;
158 }
159
160 // TODO: This creates a worktree per repetition. Ideally these examples should
161 // either be run sequentially on the same worktree, or reuse worktrees when there
162 // are more examples to run than the concurrency limit.
163 for repetition_number in 0..args.repetitions {
164 let mut example = example.clone();
165 example.set_repetition_number(repetition_number);
166
167 let name_len = example.name.len();
168 if name_len > max_name_width {
169 max_name_width = example.name.len();
170 }
171
172 examples.push(example);
173 }
174 }
175
176 println!("Skipped examples: {}\n", skipped.join(", "));
177
178 if examples.is_empty() {
179 eprintln!("Filter matched no examples");
180 return cx.update(|cx| cx.quit());
181 }
182
183 let mut repo_urls = HashSet::default();
184 let mut clone_tasks = Vec::new();
185
186 for (i, example) in examples.iter_mut().enumerate() {
187 let color = COLORS[i % COLORS.len()].to_string();
188 example.set_log_prefix_style(&color, max_name_width);
189
190 println!(
191 "{}Logging to: {}",
192 example.log_prefix,
193 example.example_output_directory().display()
194 );
195
196 let repo_url = example.base.url.clone();
197 if repo_urls.insert(repo_url.clone()) {
198 let repo_path = repo_path_for_url(&repo_url);
199
200 if !repo_path.join(".git").is_dir() {
201 println!(
202 "{:<width$} < {}",
203 "↓ Cloning",
204 repo_url,
205 width = max_name_width
206 );
207
208 let git_task = cx.spawn(async move |_cx| {
209 std::fs::create_dir_all(&repo_path)?;
210 run_git(&repo_path, &["init"]).await?;
211 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
212 });
213
214 clone_tasks.push(git_task);
215 } else {
216 println!(
217 "{:<width$} < {}",
218 "✔︎ Already cloned",
219 repo_url,
220 width = max_name_width
221 );
222
223 let actual_origin =
224 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
225 if actual_origin != repo_url {
226 return Err(anyhow!(
227 "remote origin {} does not match expected origin {}",
228 actual_origin,
229 repo_url,
230 ));
231 }
232 }
233 }
234 }
235
236 future::join_all(clone_tasks).await;
237
238 for example in examples.iter_mut() {
239 example.setup().await?;
240 }
241
242 let judge_repetitions = args.judge_repetitions;
243 let concurrency = args.concurrency;
244
245 let tasks = examples.iter().map(|example| {
246 let app_state = app_state.clone();
247 let model = model.clone();
248 let example = example.clone();
249 cx.spawn(async move |cx| {
250 let result = async {
251 let run_output = cx
252 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
253 .await?;
254 let judge_tasks = (0..judge_repetitions).map(|round| {
255 run_judge_repetition(
256 example.clone(),
257 model.clone(),
258 &run_output,
259 round,
260 cx,
261 )
262 });
263 let judge_outputs = future::join_all(judge_tasks).await;
264 anyhow::Ok((run_output, judge_outputs))
265 }
266 .await;
267 (example, result)
268 })
269 });
270
271 let results = futures::stream::iter(tasks)
272 .buffer_unordered(concurrency)
273 .collect::<Vec<_>>()
274 .await;
275
276 println!("\n\n");
277 print_header("EVAL RESULTS");
278
279 let mut diff_scores = Vec::new();
280 let mut thread_scores = Vec::new();
281 let mut error_count = 0;
282
283 for (example, result) in results {
284 print_header(&example.name);
285
286 match result {
287 Err(err) => {
288 println!("💥 {}{:?}", example.log_prefix, err);
289 error_count += 1;
290 }
291 Ok((run_output, judge_results)) => {
292 cumulative_tool_metrics.merge(&run_output.tool_metrics);
293
294 println!("┌───────┬──────┬────────┐");
295 println!("│ Judge │ Diff │ Thread │");
296 println!("├───────┼──────┼────────┤");
297
298 for (i, judge_result) in judge_results.iter().enumerate() {
299 match judge_result {
300 Ok(judge_output) => {
301 let diff_score = judge_output.diff.score;
302 diff_scores.push(diff_score);
303
304 let thread_display = if let Some(thread) = &judge_output.thread
305 {
306 let thread_score = thread.score;
307 thread_scores.push(thread_score);
308 format!("{}", thread_score)
309 } else {
310 "N/A".to_string()
311 };
312
313 println!(
314 "|{:^7}│{:^6}│{:^8}│",
315 i + 1,
316 diff_score,
317 thread_display
318 );
319 }
320 Err(err) => {
321 println!("|{:^7}│{:^6}│{:^8}│{:?}", i + 1, "N/A", "N/A", err);
322 }
323 }
324 }
325
326 println!("└───────┴──────┴────────┘");
327
328 println!("{}", run_output.tool_metrics);
329 }
330 }
331 println!(
332 "{} > {}",
333 " ".repeat(max_name_width),
334 example.example_output_directory().display()
335 );
336 }
337
338 let diff_score_count = diff_scores.len();
339 let average_diff_score = diff_scores
340 .into_iter()
341 .map(|score| score as f32)
342 .sum::<f32>()
343 / (diff_score_count as f32);
344
345 if error_count > 0 {
346 println!("\n{error_count} examples failed to run!");
347 }
348
349 if diff_score_count > 0 {
350 println!("\nAverage code diff score: {average_diff_score}");
351 }
352
353 let thread_score_count = thread_scores.len();
354
355 // We might have gotten no thread scores if we weren't asked to judge the thread.
356 if thread_score_count > 0 {
357 let average_thread_score = thread_scores
358 .into_iter()
359 .map(|score| score as f32)
360 .sum::<f32>()
361 / (thread_score_count as f32);
362
363 if diff_score_count > 0 {
364 println!("\nAverage thread score: {average_thread_score}");
365 }
366 }
367
368 print_header("CUMULATIVE TOOL METRICS");
369 println!("{}", cumulative_tool_metrics);
370
371 std::thread::sleep(std::time::Duration::from_secs(2));
372
373 app_state.client.telemetry().flush_events();
374
375 cx.update(|cx| cx.quit())
376 })
377 .detach_and_log_err(cx);
378 });
379}
380
381fn list_all_examples() -> Result<Vec<PathBuf>> {
382 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
383 let entries = std::fs::read_dir(path).unwrap();
384 let mut result_paths = Vec::new();
385 for entry in entries {
386 let entry = entry?;
387 let path = entry.path();
388 if path.is_dir() {
389 result_paths.push(path);
390 }
391 }
392 Ok(result_paths)
393}
394
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// An instance is built by `init` and shared behind an `Arc`.
pub struct AgentAppState {
    /// Shared language registry.
    pub languages: Arc<LanguageRegistry>,
    /// Shared client; `main` also uses it to reach telemetry.
    pub client: Arc<Client>,
    /// User store entity created from `client`.
    pub user_store: Entity<UserStore>,
    /// File system implementation.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime handle.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder, also passed to `agent::init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
406
/// Initializes the global state required to run the agent headlessly and returns the shared
/// application state.
///
/// The subsystems are initialized in a fixed order: release channel, Tokio integration,
/// settings store, HTTP client, project settings, client, file system, language registry,
/// user store, extensions, Node runtime, then the language / model / tool / prompt / agent
/// subsystems. Finally the bundled `runner_settings.json` is applied as the user settings.
///
/// # Panics
///
/// Panics if the default settings or the bundled runner settings cannot be applied, or if the
/// HTTP client cannot be created.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install the settings store with the default settings as a global.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Use the proxy from the settings when it parses as a URI; otherwise fall back to
    // whatever `read_proxy_from_env` returns.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // The HTTP client is created while the Tokio handle is entered.
        // NOTE(review): presumably because the client needs a Tokio runtime context when it
        // is constructed — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // The production client is created after the HTTP client above was installed; its own
    // HTTP client then replaces the app-level one.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Node binary options are recomputed from the project settings in the `SettingsStore`
    // observer below and sent to the Node runtime through a watch channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the runner settings, embedded at compile time, as the user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
511
512pub fn find_model(
513 model_name: &str,
514 model_registry: &LanguageModelRegistry,
515 cx: &App,
516) -> anyhow::Result<Arc<dyn LanguageModel>> {
517 let model = model_registry
518 .available_models(cx)
519 .find(|model| model.id().0 == model_name);
520
521 let Some(model) = model else {
522 return Err(anyhow!(
523 "No language model named {} was available. Available models: {}",
524 model_name,
525 model_registry
526 .available_models(cx)
527 .map(|model| model.id().0.clone())
528 .collect::<Vec<_>>()
529 .join(", ")
530 ));
531 };
532
533 Ok(model)
534}
535
536pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
537 (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
538}
539
540pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
541 futures::executor::block_on(async {
542 get_current_commit_id(repo_path).await.unwrap_or_default()
543 })
544}
545
546async fn run_judge_repetition(
547 example: Example,
548 model: Arc<dyn LanguageModel>,
549 run_output: &RunOutput,
550 round: u32,
551 cx: &AsyncApp,
552) -> Result<JudgeOutput> {
553 let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
554
555 if let Ok(judge_output) = &judge_result {
556 let cohort_id = example
557 .run_directory_path
558 .file_name()
559 .map(|name| name.to_string_lossy().to_string())
560 .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
561
562 let path = std::path::Path::new(".");
563 let commit_id = get_current_commit_id(path).await.unwrap_or_default();
564
565 if let Some(thread) = &judge_output.thread {
566 telemetry::event!(
567 "Agent Eval Completed",
568 cohort_id = cohort_id,
569 example_name = example.name.clone(),
570 round = round,
571 diff_score = judge_output.diff.score,
572 diff_analysis = judge_output.diff.analysis,
573 thread_score = thread.score,
574 thread_analysis = thread.analysis,
575 tool_metrics = run_output.tool_metrics,
576 response_count = run_output.response_count,
577 token_usage = run_output.token_usage,
578 model = model.telemetry_id(),
579 model_provider = model.provider_id().to_string(),
580 repository_url = example.base.url.clone(),
581 repository_revision = example.base.revision.clone(),
582 diagnostics_before = run_output.diagnostics_before,
583 diagnostics_after = run_output.diagnostics_after,
584 commit_id = commit_id
585 );
586 } else {
587 telemetry::event!(
588 "Agent Eval Completed",
589 cohort_id = cohort_id,
590 example_name = example.name.clone(),
591 round = round,
592 diff_score = judge_output.diff.score,
593 diff_analysis = judge_output.diff.analysis,
594 tool_metrics = run_output.tool_metrics,
595 response_count = run_output.response_count,
596 token_usage = run_output.token_usage,
597 model = model.telemetry_id(),
598 model_provider = model.provider_id().to_string(),
599 repository_url = example.base.url.clone(),
600 repository_revision = example.base.revision.clone(),
601 diagnostics_before = run_output.diagnostics_before,
602 diagnostics_after = run_output.diagnostics_after,
603 commit_id = commit_id
604 );
605 }
606 }
607
608 judge_result
609}
610
/// Prints `header` centered in a 40-column field between two horizontal rules, with a blank
/// line before and after the banner.
fn print_header(header: &str) {
    const RULE: &str = "========================================";

    println!("\n{}", RULE);
    println!("{:^40}", header);
    println!("{}\n", RULE);
}