1mod example;
2mod ids;
3mod tool_metrics;
4
5pub(crate) use example::*;
6use parking_lot::Mutex;
7pub(crate) use tool_metrics::*;
8
9use ::fs::RealFs;
10use anyhow::{Result, anyhow};
11use clap::Parser;
12use client::{Client, ProxySettings, UserStore};
13use collections::{HashMap, HashSet};
14use extension::ExtensionHostProxy;
15use futures::future;
16use gpui::http_client::{Uri, read_proxy_from_env};
17use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
18use gpui_tokio::Tokio;
19use language::LanguageRegistry;
20use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
21use node_runtime::{NodeBinaryOptions, NodeRuntime};
22use project::Project;
23use project::project_settings::ProjectSettings;
24use prompt_store::PromptBuilder;
25use release_channel::AppVersion;
26use reqwest_client::ReqwestClient;
27use settings::{Settings, SettingsStore};
28use std::collections::VecDeque;
29use std::env;
30use std::path::{Path, PathBuf};
31use std::sync::Arc;
32use util::ResultExt as _;
33
34#[derive(Parser, Debug)]
35#[command(name = "eval", disable_version_flag = true)]
36struct Args {
37 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
38 #[arg(value_name = "EXAMPLE_SUBSTRING")]
39 examples: Vec<String>,
40 /// Model to use (default: "claude-3-7-sonnet-latest")
41 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
42 model: String,
43 #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
44 languages: Vec<String>,
45 /// How many times to run each example.
46 #[arg(long, default_value = "1")]
47 repetitions: usize,
48 /// Maximum number of examples to run concurrently.
49 #[arg(long, default_value = "10")]
50 concurrency: usize,
51}
52
53fn main() {
54 env_logger::init();
55
56 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
57 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
58 let session_id = uuid::Uuid::new_v4().to_string();
59 let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
60 let run_id = match env::var("GITHUB_RUN_ID") {
61 Ok(run_id) => format!("github/{}", run_id),
62 Err(_) => format!("local/{}", run_timestamp),
63 };
64
65 let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
66 .parent()
67 .unwrap()
68 .parent()
69 .unwrap();
70 let eval_crate_dir = root_dir.join("crates/eval");
71 let repos_dir = eval_crate_dir.join("repos");
72 let worktrees_dir = eval_crate_dir.join("worktrees");
73 let examples_dir = eval_crate_dir.join("examples");
74 let runs_dir = eval_crate_dir.join("runs");
75 let run_dir = runs_dir.join(format!("{}", run_timestamp));
76 std::fs::create_dir_all(&run_dir).unwrap();
77 std::fs::create_dir_all(&repos_dir).unwrap();
78 std::fs::create_dir_all(&worktrees_dir).unwrap();
79 std::fs::create_dir_all(&examples_dir).unwrap();
80 std::fs::create_dir_all(&paths::config_dir()).unwrap();
81
82 let zed_commit_sha = commit_sha_for_path(root_dir);
83 let zed_branch_name = git_branch_for_path(root_dir);
84 let args = Args::parse();
85 let all_available_examples = list_all_examples(&examples_dir).unwrap();
86
87 let example_paths = all_available_examples
88 .iter()
89 .filter_map(|example_path| {
90 let name = example_path.file_name()?.to_string_lossy();
91 if args.examples.is_empty()
92 || args
93 .examples
94 .iter()
95 .any(|name_substring| name.contains(name_substring))
96 {
97 Some(example_path.clone())
98 } else {
99 None
100 }
101 })
102 .collect::<Vec<_>>();
103
104 let http_client = Arc::new(ReqwestClient::new());
105 let app = Application::headless().with_http_client(http_client.clone());
106
107 app.run(move |cx| {
108 let app_state = init(cx);
109
110 let telemetry = app_state.client.telemetry();
111 telemetry.start(system_id, installation_id, session_id, cx);
112
113 let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
114 && telemetry.has_checksum_seed();
115 if enable_telemetry {
116 println!("Telemetry enabled");
117 telemetry::event!(
118 "Agent Eval Started",
119 zed_commit_sha = zed_commit_sha,
120 zed_branch_name = zed_branch_name,
121 run_id = run_id,
122 );
123 }
124
125 let mut cumulative_tool_metrics = ToolMetrics::default();
126
127 let model_registry = LanguageModelRegistry::read_global(cx);
128 let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
129 let model_provider_id = model.provider_id();
130 let model_provider = model_registry.provider(&model_provider_id).unwrap();
131
132 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
133 registry.set_default_model(
134 Some(ConfiguredModel {
135 provider: model_provider.clone(),
136 model: model.clone(),
137 }),
138 cx,
139 );
140 });
141
142 let authenticate_task = model_provider.authenticate(cx);
143
144 cx.spawn(async move |cx| {
145 authenticate_task.await.unwrap();
146
147 let mut examples = Vec::new();
148
149 const COLORS: [&str; 12] = [
150 "\x1b[31m", // Red
151 "\x1b[32m", // Green
152 "\x1b[33m", // Yellow
153 "\x1b[34m", // Blue
154 "\x1b[35m", // Magenta
155 "\x1b[36m", // Cyan
156 "\x1b[91m", // Bright Red
157 "\x1b[92m", // Bright Green
158 "\x1b[93m", // Bright Yellow
159 "\x1b[94m", // Bright Blue
160 "\x1b[95m", // Bright Magenta
161 "\x1b[96m", // Bright Cyan
162 ];
163
164 let mut skipped = Vec::new();
165
166 for example_path in &example_paths {
167 let example = Example::load_from_directory(
168 example_path,
169 &run_dir,
170 &worktrees_dir,
171 &repos_dir,
172 )?;
173
174 if !example
175 .base
176 .language_extension
177 .as_ref()
178 .map_or(false, |lang| args.languages.contains(lang))
179 {
180 skipped.push(example.name);
181 continue;
182 }
183
184 examples.extend(example.repeat(args.repetitions));
185 }
186
187 println!("Skipped examples: {}\n", skipped.join(", "));
188
189 if examples.is_empty() {
190 eprintln!("Filter matched no examples");
191 return cx.update(|cx| cx.quit());
192 }
193
194 let mut repo_urls = HashSet::default();
195 let mut clone_tasks = Vec::new();
196
197 let max_name_width = examples
198 .iter()
199 .map(|e| e.repetition_name().len())
200 .max()
201 .unwrap_or(0);
202 for (i, example) in examples.iter_mut().enumerate() {
203 let color = COLORS[i % COLORS.len()].to_string();
204 example.set_log_prefix_style(&color, max_name_width);
205
206 println!(
207 "{}Logging to: {}",
208 example.log_prefix,
209 example.run_directory_path().display()
210 );
211
212 let repo_url = example.base.url.clone();
213 if repo_urls.insert(repo_url.clone()) {
214 let repo_path = example.repo_path.clone();
215
216 if !repo_path.join(".git").is_dir() {
217 println!(
218 "{:<width$} < {}",
219 "↓ Cloning",
220 repo_url,
221 width = max_name_width
222 );
223
224 let git_task = cx.spawn(async move |_cx| {
225 std::fs::create_dir_all(&repo_path)?;
226 run_git(&repo_path, &["init"]).await?;
227 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
228 });
229
230 clone_tasks.push(git_task);
231 } else {
232 println!(
233 "{:<width$} < {}",
234 "✔︎ Already cloned",
235 repo_url,
236 width = max_name_width
237 );
238
239 let actual_origin =
240 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
241 if actual_origin != repo_url {
242 return Err(anyhow!(
243 "remote origin {} does not match expected origin {}",
244 actual_origin,
245 repo_url,
246 ));
247 }
248 }
249 }
250 }
251
252 future::join_all(clone_tasks).await;
253
254 for example in examples.iter_mut() {
255 example.fetch().await?;
256 }
257
258 let examples = Arc::new(Mutex::new(VecDeque::from(examples)));
259 let results_by_example_name = Arc::new(Mutex::new(HashMap::default()));
260
261 future::join_all((0..args.concurrency).map(|_| {
262 let app_state = app_state.clone();
263 let model = model.clone();
264 let zed_commit_sha = zed_commit_sha.clone();
265 let zed_branch_name = zed_branch_name.clone();
266 let run_id = run_id.clone();
267 let examples = examples.clone();
268 let results = results_by_example_name.clone();
269 cx.spawn(async move |cx| {
270 loop {
271 let Some(mut example) = examples.lock().pop_front() else {
272 break;
273 };
274 let result = async {
275 example.setup().await?;
276 let run_output = cx
277 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
278 .await?;
279 let judge_output = judge_example(
280 example.clone(),
281 model.clone(),
282 &zed_commit_sha,
283 &zed_branch_name,
284 &run_id,
285 &run_output,
286 enable_telemetry,
287 cx,
288 )
289 .await;
290 anyhow::Ok((run_output, judge_output))
291 }
292 .await;
293 results
294 .lock()
295 .entry(example.name.clone())
296 .or_insert(Vec::new())
297 .push((example.clone(), result));
298 }
299 })
300 }))
301 .await;
302
303 println!("\n\n");
304 print_header("EVAL RESULTS");
305
306 let mut diff_scores = Vec::new();
307 let mut thread_scores = Vec::new();
308 let mut error_count = 0;
309
310 for (example_name, results) in results_by_example_name.lock().iter_mut() {
311 print_header(&example_name);
312
313 results.sort_unstable_by_key(|(example, _)| example.repetition);
314 let mut example_cumulative_tool_metrics = ToolMetrics::default();
315
316 println!("┌───────┬──────┬────────┐");
317 println!("│ Round │ Diff │ Thread │");
318 println!("├───────┼──────┼────────┤");
319 for (example, result) in results {
320 let run_dir_path = example.run_directory_path();
321 let relative_run_dir_path = run_dir_path.strip_prefix(root_dir).unwrap();
322
323 match result {
324 Err(err) => {
325 println!(
326 "|{:^7}│{:^6}│{:^8}│ {:?}{}",
327 example.repetition,
328 "N/A",
329 "N/A",
330 err,
331 relative_run_dir_path.display()
332 );
333 error_count += 1;
334 }
335 Ok((run_output, judge_result)) => {
336 cumulative_tool_metrics.merge(&run_output.tool_metrics);
337 example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
338
339 match judge_result {
340 Ok(judge_output) => {
341 diff_scores.push(judge_output.diff.score());
342 thread_scores.push(judge_output.thread.score());
343 println!(
344 "|{:^7}│{:^6}│{:^8}│ {}",
345 example.repetition,
346 format!("{}%", judge_output.diff.score()),
347 format!("{}%", judge_output.thread.score()),
348 relative_run_dir_path.display()
349 );
350 }
351 Err(err) => {
352 println!(
353 "|{:^7}│{:^6}│{:^8}│{:?}│ {}",
354 example.repetition,
355 "N/A",
356 "N/A",
357 err,
358 relative_run_dir_path.display()
359 );
360 }
361 }
362 }
363 }
364 }
365
366 println!("└───────┴──────┴────────┘");
367 println!("{}", example_cumulative_tool_metrics);
368 }
369
370 let diff_score_count = diff_scores.len();
371 let average_diff_score = diff_scores
372 .into_iter()
373 .map(|score| score as f32)
374 .sum::<f32>()
375 / (diff_score_count as f32);
376
377 if error_count > 0 {
378 println!("\n{error_count} examples failed to run!");
379 }
380
381 println!("\nAverage code diff score: {average_diff_score}");
382
383 let thread_score_count = thread_scores.len();
384 let average_thread_score = thread_scores
385 .into_iter()
386 .map(|score| score as f32)
387 .sum::<f32>()
388 / (thread_score_count as f32);
389
390 println!("\nAverage thread score: {average_thread_score}");
391
392 print_header("CUMULATIVE TOOL METRICS");
393 println!("{}", cumulative_tool_metrics);
394
395 app_state.client.telemetry().flush_events().await;
396
397 cx.update(|cx| cx.quit())
398 })
399 .detach_and_log_err(cx);
400 });
401}
402
/// Lists the immediate subdirectories of `examples_dir`; each one is an example.
///
/// `examples_dir` is canonicalized first, so the returned paths are absolute. They are sorted,
/// so the example order (and therefore queue order and log colors) is stable across runs and
/// platforms. Plain files in the directory are ignored.
///
/// # Errors
///
/// Returns an error if `examples_dir` cannot be canonicalized or read, or if any directory
/// entry cannot be read.
fn list_all_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
    let examples_dir = std::fs::canonicalize(examples_dir)?;
    let mut result_paths = Vec::new();
    for entry in std::fs::read_dir(examples_dir)? {
        let path = entry?.path();
        if path.is_dir() {
            result_paths.push(path);
        }
    }
    // `read_dir` yields entries in an unspecified, platform-dependent order.
    result_paths.sort();
    Ok(result_paths)
}
416
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by [`init`] and cloned (behind an `Arc`) into every worker that runs examples.
pub struct AgentAppState {
    /// Language registry; `init` sets its language-server download dir to `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// The `Client::production` client; also the source of telemetry and of the HTTP client.
    pub client: Arc<Client>,
    /// User store created from `client`.
    pub user_store: Entity<UserStore>,
    /// Real filesystem (`RealFs`) created with no explicit git binary path.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options follow the project `node` settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded via `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
428
/// Initializes the global state the eval runner needs inside a headless `App` and returns the
/// shared handles bundled as an [`AgentAppState`].
///
/// In order, this:
/// - calls `release_channel::init` (with a default version) and `gpui_tokio::init`;
/// - installs a `SettingsStore` seeded with the default settings;
/// - builds an HTTP client using the configured proxy (falling back to the environment) and a
///   Zed User-Agent;
/// - creates the production `Client`, a `RealFs`, the `LanguageRegistry`, the `UserStore` and a
///   `NodeRuntime`;
/// - initializes the language, language-model, tool, context-server, prompt and agent
///   subsystems;
/// - finally applies `runner_settings.json` as the user settings.
///
/// # Panics
///
/// Panics if the default settings or `runner_settings.json` fail to load, or if the HTTP
/// client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // NOTE(review): `SemanticVersion::default()` is presumably 0.0.0; it is also what
    // `AppVersion::global` reports in the User-Agent below — confirm that is acceptable.
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer a proxy from settings (if it parses as a URI); otherwise read one from the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    // The reqwest client is built while the Tokio runtime handle is entered.
    let http = {
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // NOTE(review): this replaces the global HTTP client set just above with the client's own;
    // presumably `Client::production` wraps the proxy-configured one — confirm.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Recompute the Node binary options from the project `node` settings whenever the global
    // SettingsStore changes, and publish them to the NodeRuntime through a watch channel that
    // starts out as `None`.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            // With an explicit (tilde-expanded) node path, npm defaults to an `npm` sibling
            // of the node binary unless `npm_path` is also configured.
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    // The runner is headless, so the prompt builder is told stdout is not a terminal.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Runner-specific settings are applied last, as user settings layered over the defaults.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
533
534pub fn find_model(
535 model_name: &str,
536 model_registry: &LanguageModelRegistry,
537 cx: &App,
538) -> anyhow::Result<Arc<dyn LanguageModel>> {
539 let model = model_registry
540 .available_models(cx)
541 .find(|model| model.id().0 == model_name);
542
543 let Some(model) = model else {
544 return Err(anyhow!(
545 "No language model named {} was available. Available models: {}",
546 model_name,
547 model_registry
548 .available_models(cx)
549 .map(|model| model.id().0.clone())
550 .collect::<Vec<_>>()
551 .join(", ")
552 ));
553 };
554
555 Ok(model)
556}
557
558pub fn commit_sha_for_path(repo_path: &Path) -> String {
559 futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
560}
561
562pub fn git_branch_for_path(repo_path: &Path) -> String {
563 match std::env::var("GITHUB_REF_NAME") {
564 Ok(branch) => branch,
565 Err(_) => {
566 futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
567 .unwrap_or_else(|_| "unknown".to_string())
568 }
569 }
570}
571
572async fn judge_example(
573 example: Example,
574 model: Arc<dyn LanguageModel>,
575 zed_commit_sha: &str,
576 zed_branch_name: &str,
577 run_id: &str,
578 run_output: &RunOutput,
579 enable_telemetry: bool,
580 cx: &AsyncApp,
581) -> Result<JudgeOutput> {
582 let judge_output = example.judge(model.clone(), &run_output, cx).await;
583
584 let diff_evaluation;
585 let thread_evaluation;
586 if let Ok(output) = judge_output.as_ref() {
587 diff_evaluation = Some(output.diff.clone());
588 thread_evaluation = Some(output.thread.clone());
589 } else {
590 diff_evaluation = None;
591 thread_evaluation = None;
592 }
593
594 if enable_telemetry {
595 telemetry::event!(
596 "Agent Example Evaluated",
597 zed_commit_sha = zed_commit_sha,
598 zed_branch_name = zed_branch_name,
599 run_id = run_id,
600 example_name = example.name.clone(),
601 example_repetition = example.repetition,
602 diff_evaluation = diff_evaluation,
603 thread_evaluation = thread_evaluation,
604 tool_metrics = run_output.tool_metrics,
605 response_count = run_output.response_count,
606 token_usage = run_output.token_usage,
607 model = model.telemetry_id(),
608 model_provider = model.provider_id().to_string(),
609 repository_url = example.base.url.clone(),
610 repository_revision = example.base.revision.clone(),
611 diagnostic_summary_before = run_output.diagnostic_summary_before,
612 diagnostic_summary_after = run_output.diagnostic_summary_after,
613 diagnostics_before = run_output.diagnostics_before,
614 diagnostics_after = run_output.diagnostics_after,
615 );
616 }
617
618 judge_output
619}
620
/// Prints `header` centered in a 40-column banner.
///
/// The banner is a blank line, a rule of `=`, the centered header, then another rule followed
/// by a blank line.
fn print_header(header: &str) {
    const RULE: &str = "========================================";
    println!("\n{}", RULE);
    println!("{header:^40}");
    println!("{}\n", RULE);
}