1mod example;
2mod ids;
3mod tool_metrics;
4
5pub(crate) use example::*;
6pub(crate) use tool_metrics::*;
7
8use ::fs::RealFs;
9use anyhow::{Result, anyhow};
10use clap::Parser;
11use client::{Client, ProxySettings, UserStore};
12use collections::HashSet;
13use extension::ExtensionHostProxy;
14use futures::{StreamExt, future};
15use gpui::http_client::{Uri, read_proxy_from_env};
16use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
17use gpui_tokio::Tokio;
18use language::LanguageRegistry;
19use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
20use node_runtime::{NodeBinaryOptions, NodeRuntime};
21use project::Project;
22use project::project_settings::ProjectSettings;
23use prompt_store::PromptBuilder;
24use release_channel::AppVersion;
25use reqwest_client::ReqwestClient;
26use settings::{Settings, SettingsStore};
27use std::env;
28use std::path::{Path, PathBuf};
29use std::sync::Arc;
30use util::ResultExt as _;
31
32#[derive(Parser, Debug)]
33#[command(name = "eval", disable_version_flag = true)]
34struct Args {
35 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
36 #[arg(value_name = "EXAMPLE_SUBSTRING")]
37 examples: Vec<String>,
38 /// Model to use (default: "claude-3-7-sonnet-latest")
39 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
40 model: String,
41 #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
42 languages: Vec<String>,
43 /// How many times to run each example. Note that this is currently not very efficient as N
44 /// worktrees will be created for the examples.
45 #[arg(long, default_value = "1")]
46 repetitions: u32,
47 /// How many times to run the judge on each example run.
48 #[arg(long, default_value = "3")]
49 judge_repetitions: u32,
50 /// Maximum number of examples to run concurrently.
51 #[arg(long, default_value = "10")]
52 concurrency: usize,
53}
54
55fn main() {
56 env_logger::init();
57
58 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
59 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
60 let session_id = uuid::Uuid::new_v4().to_string();
61 let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
62 let run_id = match env::var("GITHUB_RUN_ID") {
63 Ok(run_id) => format!("github/{}", run_id),
64 Err(_) => format!("local/{}", run_timestamp),
65 };
66
67 let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
68 .parent()
69 .unwrap()
70 .parent()
71 .unwrap();
72 let eval_crate_dir = root_dir.join("crates/eval");
73 let repos_dir = eval_crate_dir.join("repos");
74 let worktrees_dir = eval_crate_dir.join("worktrees");
75 let examples_dir = eval_crate_dir.join("examples");
76 let runs_dir = eval_crate_dir.join("runs");
77 let run_dir = runs_dir.join(format!("{}", run_timestamp));
78 std::fs::create_dir_all(&run_dir).unwrap();
79 std::fs::create_dir_all(&repos_dir).unwrap();
80 std::fs::create_dir_all(&worktrees_dir).unwrap();
81 std::fs::create_dir_all(&examples_dir).unwrap();
82 std::fs::create_dir_all(&paths::config_dir()).unwrap();
83
84 let zed_commit_sha = commit_sha_for_path(root_dir);
85 let zed_branch_name = git_branch_for_path(root_dir);
86 let args = Args::parse();
87 let all_available_examples = list_all_examples(&examples_dir).unwrap();
88
89 let example_paths = all_available_examples
90 .iter()
91 .filter_map(|example_path| {
92 let name = example_path.file_name()?.to_string_lossy();
93 if args.examples.is_empty()
94 || args
95 .examples
96 .iter()
97 .any(|name_substring| name.contains(name_substring))
98 {
99 Some(example_path.clone())
100 } else {
101 None
102 }
103 })
104 .collect::<Vec<_>>();
105
106 let http_client = Arc::new(ReqwestClient::new());
107 let app = Application::headless().with_http_client(http_client.clone());
108
109 app.run(move |cx| {
110 let app_state = init(cx);
111
112 let telemetry = app_state.client.telemetry();
113 telemetry.start(system_id, installation_id, session_id, cx);
114
115 let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
116 && telemetry.has_checksum_seed();
117 if enable_telemetry {
118 println!("Telemetry enabled");
119 telemetry::event!(
120 "Agent Eval Started",
121 zed_commit_sha = zed_commit_sha,
122 zed_branch_name = zed_branch_name,
123 run_id = run_id,
124 );
125 }
126
127 let mut cumulative_tool_metrics = ToolMetrics::default();
128
129 let model_registry = LanguageModelRegistry::read_global(cx);
130 let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
131 let model_provider_id = model.provider_id();
132 let model_provider = model_registry.provider(&model_provider_id).unwrap();
133
134 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
135 registry.set_default_model(
136 Some(ConfiguredModel {
137 provider: model_provider.clone(),
138 model: model.clone(),
139 }),
140 cx,
141 );
142 });
143
144 let authenticate_task = model_provider.authenticate(cx);
145
146 cx.spawn(async move |cx| {
147 authenticate_task.await.unwrap();
148
149 let mut examples = Vec::new();
150
151 const COLORS: [&str; 12] = [
152 "\x1b[31m", // Red
153 "\x1b[32m", // Green
154 "\x1b[33m", // Yellow
155 "\x1b[34m", // Blue
156 "\x1b[35m", // Magenta
157 "\x1b[36m", // Cyan
158 "\x1b[91m", // Bright Red
159 "\x1b[92m", // Bright Green
160 "\x1b[93m", // Bright Yellow
161 "\x1b[94m", // Bright Blue
162 "\x1b[95m", // Bright Magenta
163 "\x1b[96m", // Bright Cyan
164 ];
165
166 let mut max_name_width = 0;
167 let mut skipped = Vec::new();
168
169 for example_path in &example_paths {
170 let example = Example::load_from_directory(
171 example_path,
172 &run_dir,
173 &worktrees_dir,
174 &repos_dir,
175 )?;
176
177 if !example
178 .base
179 .language_extension
180 .as_ref()
181 .map_or(false, |lang| args.languages.contains(lang))
182 {
183 skipped.push(example.name);
184 continue;
185 }
186
187 // TODO: This creates a worktree per repetition. Ideally these examples should
188 // either be run sequentially on the same worktree, or reuse worktrees when there
189 // are more examples to run than the concurrency limit.
190 for repetition_number in 0..args.repetitions {
191 let mut example = example.clone();
192 example.set_repetition_number(repetition_number);
193
194 let name_len = example.name.len();
195 if name_len > max_name_width {
196 max_name_width = example.name.len();
197 }
198
199 examples.push(example);
200 }
201 }
202
203 println!("Skipped examples: {}\n", skipped.join(", "));
204
205 if examples.is_empty() {
206 eprintln!("Filter matched no examples");
207 return cx.update(|cx| cx.quit());
208 }
209
210 let mut repo_urls = HashSet::default();
211 let mut clone_tasks = Vec::new();
212
213 for (i, example) in examples.iter_mut().enumerate() {
214 let color = COLORS[i % COLORS.len()].to_string();
215 example.set_log_prefix_style(&color, max_name_width);
216
217 println!(
218 "{}Logging to: {}",
219 example.log_prefix,
220 example.example_output_directory().display()
221 );
222
223 let repo_url = example.base.url.clone();
224 if repo_urls.insert(repo_url.clone()) {
225 let repo_path = example.repo_path.clone();
226
227 if !repo_path.join(".git").is_dir() {
228 println!(
229 "{:<width$} < {}",
230 "↓ Cloning",
231 repo_url,
232 width = max_name_width
233 );
234
235 let git_task = cx.spawn(async move |_cx| {
236 std::fs::create_dir_all(&repo_path)?;
237 run_git(&repo_path, &["init"]).await?;
238 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
239 });
240
241 clone_tasks.push(git_task);
242 } else {
243 println!(
244 "{:<width$} < {}",
245 "✔︎ Already cloned",
246 repo_url,
247 width = max_name_width
248 );
249
250 let actual_origin =
251 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
252 if actual_origin != repo_url {
253 return Err(anyhow!(
254 "remote origin {} does not match expected origin {}",
255 actual_origin,
256 repo_url,
257 ));
258 }
259 }
260 }
261 }
262
263 future::join_all(clone_tasks).await;
264
265 for example in examples.iter_mut() {
266 example.setup().await?;
267 }
268
269 let judge_repetitions = args.judge_repetitions;
270 let concurrency = args.concurrency;
271
272 let tasks = examples.iter().map(|example| {
273 let app_state = app_state.clone();
274 let model = model.clone();
275 let example = example.clone();
276 let zed_commit_sha = zed_commit_sha.clone();
277 let zed_branch_name = zed_branch_name.clone();
278 let run_id = run_id.clone();
279 cx.spawn(async move |cx| {
280 let result = async {
281 let run_output = cx
282 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
283 .await?;
284 let judge_tasks = (0..judge_repetitions).map(|round| {
285 run_judge_repetition(
286 example.clone(),
287 model.clone(),
288 &zed_commit_sha,
289 &zed_branch_name,
290 &run_id,
291 &run_output,
292 round,
293 enable_telemetry,
294 cx,
295 )
296 });
297 let judge_outputs = future::join_all(judge_tasks).await;
298 anyhow::Ok((run_output, judge_outputs))
299 }
300 .await;
301 (example, result)
302 })
303 });
304
305 let results = futures::stream::iter(tasks)
306 .buffer_unordered(concurrency)
307 .collect::<Vec<_>>()
308 .await;
309
310 println!("\n\n");
311 print_header("EVAL RESULTS");
312
313 let mut diff_scores = Vec::new();
314 let mut thread_scores = Vec::new();
315 let mut error_count = 0;
316
317 for (example, result) in results {
318 print_header(&example.name);
319
320 match result {
321 Err(err) => {
322 println!("💥 {}{:?}", example.log_prefix, err);
323 error_count += 1;
324 }
325 Ok((run_output, judge_results)) => {
326 cumulative_tool_metrics.merge(&run_output.tool_metrics);
327
328 println!("┌───────┬──────┬────────┐");
329 println!("│ Judge │ Diff │ Thread │");
330 println!("├───────┼──────┼────────┤");
331
332 for (i, judge_result) in judge_results.iter().enumerate() {
333 match judge_result {
334 Ok(judge_output) => {
335 let diff_score = judge_output.diff.score;
336 diff_scores.push(diff_score);
337
338 let thread_display = if let Some(thread) = &judge_output.thread
339 {
340 let thread_score = thread.score;
341 thread_scores.push(thread_score);
342 format!("{}", thread_score)
343 } else {
344 "N/A".to_string()
345 };
346
347 println!(
348 "|{:^7}│{:^6}│{:^8}│",
349 i + 1,
350 diff_score,
351 thread_display
352 );
353 }
354 Err(err) => {
355 println!("|{:^7}│{:^6}│{:^8}│{:?}", i + 1, "N/A", "N/A", err);
356 }
357 }
358 }
359
360 println!("└───────┴──────┴────────┘");
361
362 println!("{}", run_output.tool_metrics);
363 }
364 }
365 println!(
366 "{} > {}",
367 " ".repeat(max_name_width),
368 example.example_output_directory().display()
369 );
370 }
371
372 let diff_score_count = diff_scores.len();
373 let average_diff_score = diff_scores
374 .into_iter()
375 .map(|score| score as f32)
376 .sum::<f32>()
377 / (diff_score_count as f32);
378
379 if error_count > 0 {
380 println!("\n{error_count} examples failed to run!");
381 }
382
383 if diff_score_count > 0 {
384 println!("\nAverage code diff score: {average_diff_score}");
385 }
386
387 let thread_score_count = thread_scores.len();
388
389 // We might have gotten no thread scores if we weren't asked to judge the thread.
390 if thread_score_count > 0 {
391 let average_thread_score = thread_scores
392 .into_iter()
393 .map(|score| score as f32)
394 .sum::<f32>()
395 / (thread_score_count as f32);
396
397 if diff_score_count > 0 {
398 println!("\nAverage thread score: {average_thread_score}");
399 }
400 }
401
402 print_header("CUMULATIVE TOOL METRICS");
403 println!("{}", cumulative_tool_metrics);
404
405 app_state.client.telemetry().flush_events().await;
406
407 cx.update(|cx| cx.quit())
408 })
409 .detach_and_log_err(cx);
410 });
411}
412
413fn list_all_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
414 let path = std::fs::canonicalize(examples_dir).unwrap();
415 let entries = std::fs::read_dir(path).unwrap();
416 let mut result_paths = Vec::new();
417 for entry in entries {
418 let entry = entry?;
419 let path = entry.path();
420 if path.is_dir() {
421 result_paths.push(path);
422 }
423 }
424 Ok(result_paths)
425}
426
427/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
428pub struct AgentAppState {
429 pub languages: Arc<LanguageRegistry>,
430 pub client: Arc<Client>,
431 pub user_store: Entity<UserStore>,
432 pub fs: Arc<dyn fs::Fs>,
433 pub node_runtime: NodeRuntime,
434
435 // Additional fields not present in `workspace::AppState`.
436 pub prompt_builder: Arc<PromptBuilder>,
437}
438
/// Initializes the global state the headless eval app needs and returns the
/// handles examples are run against.
///
/// In order, this sets up:
/// - the release channel (with a default version) and `gpui_tokio`;
/// - a settings store seeded with the default settings;
/// - an HTTP client that uses the configured proxy, or one read from the environment;
/// - the production `Client`, a `RealFs`, the language registry and the user store;
/// - a Node runtime whose options follow the project settings;
/// - the language, model, tool, context-server, prompt and agent subsystems.
///
/// Finally it applies the bundled `../runner_settings.json` as the user settings.
///
/// # Panics
///
/// Panics if the default settings or `runner_settings.json` cannot be applied, or
/// if the HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install the settings store first: `client::init_settings` and
    // `ProxySettings::get_global` below read from it.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings; fall back to one read from the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the reqwest client while inside the Tokio runtime context;
        // presumably the client requires one (TODO confirm).
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // From here on, the app uses the HTTP client owned by `client`.
    cx.set_http_client(client.http_client().clone());

    // No explicit git binary path is given to `RealFs`.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the settings store changes, recompute the Node binary options from
    // the project's `node` settings and publish them to the Node runtime over a
    // watch channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, look for `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    // Passed to `PromptBuilder::load`; what this flag changes is not visible here.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the eval runner's bundled settings as the user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
543
544pub fn find_model(
545 model_name: &str,
546 model_registry: &LanguageModelRegistry,
547 cx: &App,
548) -> anyhow::Result<Arc<dyn LanguageModel>> {
549 let model = model_registry
550 .available_models(cx)
551 .find(|model| model.id().0 == model_name);
552
553 let Some(model) = model else {
554 return Err(anyhow!(
555 "No language model named {} was available. Available models: {}",
556 model_name,
557 model_registry
558 .available_models(cx)
559 .map(|model| model.id().0.clone())
560 .collect::<Vec<_>>()
561 .join(", ")
562 ));
563 };
564
565 Ok(model)
566}
567
568pub fn commit_sha_for_path(repo_path: &Path) -> String {
569 futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
570}
571
572pub fn git_branch_for_path(repo_path: &Path) -> String {
573 match std::env::var("GITHUB_REF_NAME") {
574 Ok(branch) => branch,
575 Err(_) => {
576 futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
577 .unwrap_or_else(|_| "unknown".to_string())
578 }
579 }
580}
581
582async fn run_judge_repetition(
583 example: Example,
584 model: Arc<dyn LanguageModel>,
585 zed_commit_sha: &str,
586 zed_branch_name: &str,
587 run_id: &str,
588 run_output: &RunOutput,
589 round: u32,
590 enable_telemetry: bool,
591 cx: &AsyncApp,
592) -> Result<JudgeOutput> {
593 let judge_output = example.judge(model.clone(), &run_output, round, cx).await;
594
595 let diff_evaluation;
596 let thread_diff_evaluation;
597 if let Ok(output) = judge_output.as_ref() {
598 diff_evaluation = Some(output.diff.clone());
599 thread_diff_evaluation = output.thread.clone();
600 } else {
601 diff_evaluation = None;
602 thread_diff_evaluation = None;
603 }
604
605 if enable_telemetry {
606 telemetry::event!(
607 "Agent Example Evaluated",
608 zed_commit_sha = zed_commit_sha,
609 zed_branch_name = zed_branch_name,
610 run_id = run_id,
611 example_name = example.name.clone(),
612 round = round,
613 diff_evaluation = diff_evaluation,
614 thread_evaluation = thread_diff_evaluation,
615 tool_metrics = run_output.tool_metrics,
616 response_count = run_output.response_count,
617 token_usage = run_output.token_usage,
618 model = model.telemetry_id(),
619 model_provider = model.provider_id().to_string(),
620 repository_url = example.base.url.clone(),
621 repository_revision = example.base.revision.clone(),
622 diagnostics_before = run_output.diagnostics_before,
623 diagnostics_after = run_output.diagnostics_after,
624 );
625 }
626
627 judge_output
628}
629
/// Prints `header` centered in a 40-column field between two `=` rules, with a
/// blank line before the first rule and after the second.
fn print_header(header: &str) {
    const RULE: &str = "========================================";
    println!("\n{RULE}");
    println!("{header:^40}");
    println!("{RULE}\n");
}