1mod example;
2mod ids;
3
4use client::{Client, ProxySettings, UserStore};
5pub(crate) use example::*;
6use telemetry;
7
8use ::fs::RealFs;
9use anyhow::{Result, anyhow};
10use clap::Parser;
11use extension::ExtensionHostProxy;
12use futures::{StreamExt, future};
13use gpui::http_client::{Uri, read_proxy_from_env};
14use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
15use gpui_tokio::Tokio;
16use language::LanguageRegistry;
17use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
18use node_runtime::{NodeBinaryOptions, NodeRuntime};
19use project::Project;
20use project::project_settings::ProjectSettings;
21use prompt_store::PromptBuilder;
22use release_channel::AppVersion;
23use reqwest_client::ReqwestClient;
24use settings::{Settings, SettingsStore};
25use std::collections::HashSet;
26use std::path::{Path, PathBuf};
27use std::sync::Arc;
28use std::usize;
29use util::ResultExt as _;
30
/// Directory under which each eval invocation creates its own timestamped run directory.
///
/// The path is relative, so it is resolved against the process's current working directory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
32
33#[derive(Parser, Debug)]
34#[command(name = "eval", disable_version_flag = true)]
35struct Args {
36 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
37 #[arg(value_name = "EXAMPLE_SUBSTRING")]
38 examples: Vec<String>,
39 /// Model to use (default: "claude-3-7-sonnet-latest")
40 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
41 model: String,
42 #[arg(long, value_delimiter = ',')]
43 languages: Option<Vec<String>>,
44 /// How many times to run each example. Note that this is currently not very efficient as N
45 /// worktrees will be created for the examples.
46 #[arg(long, default_value = "1")]
47 repetitions: u32,
48 /// How many times to run the judge on each example run.
49 #[arg(long, default_value = "3")]
50 judge_repetitions: u32,
51 /// Maximum number of examples to run concurrently.
52 #[arg(long, default_value = "10")]
53 concurrency: usize,
54}
55
56fn main() {
57 env_logger::init();
58
59 let args = Args::parse();
60 let all_available_examples = list_all_examples().unwrap();
61 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
62
63 let example_paths = all_available_examples
64 .iter()
65 .filter_map(|example_path| {
66 let name = example_path.file_name()?.to_string_lossy();
67 if args.examples.is_empty()
68 || args
69 .examples
70 .iter()
71 .any(|name_substring| name.contains(name_substring))
72 {
73 Some(example_path.clone())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<_>>();
79
80 let http_client = Arc::new(ReqwestClient::new());
81 let app = Application::headless().with_http_client(http_client.clone());
82
83 app.run(move |cx| {
84 let app_state = init(cx);
85
86 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
87 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
88 let session_id = uuid::Uuid::new_v4().to_string();
89
90 app_state
91 .client
92 .telemetry()
93 .start(system_id, installation_id, session_id, cx);
94
95 let model_registry = LanguageModelRegistry::read_global(cx);
96 let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
97 let model_provider_id = model.provider_id();
98 let model_provider = model_registry.provider(&model_provider_id).unwrap();
99
100 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
101 registry.set_default_model(
102 Some(ConfiguredModel {
103 provider: model_provider.clone(),
104 model: model.clone(),
105 }),
106 cx,
107 );
108 });
109
110 let authenticate_task = model_provider.authenticate(cx);
111
112 cx.spawn(async move |cx| {
113 authenticate_task.await.unwrap();
114
115 std::fs::create_dir_all(REPOS_DIR)?;
116 std::fs::create_dir_all(WORKTREES_DIR)?;
117
118 let run_dir = Path::new(RUNS_DIR).join(format!(
119 "{}",
120 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
121 ));
122 std::fs::create_dir_all(&run_dir)?;
123
124 let mut examples = Vec::new();
125
126 const COLORS: [&str; 12] = [
127 "\x1b[31m", // Red
128 "\x1b[32m", // Green
129 "\x1b[33m", // Yellow
130 "\x1b[34m", // Blue
131 "\x1b[35m", // Magenta
132 "\x1b[36m", // Cyan
133 "\x1b[91m", // Bright Red
134 "\x1b[92m", // Bright Green
135 "\x1b[93m", // Bright Yellow
136 "\x1b[94m", // Bright Blue
137 "\x1b[95m", // Bright Magenta
138 "\x1b[96m", // Bright Cyan
139 ];
140
141 let mut max_name_width = 0;
142 let mut skipped = Vec::new();
143
144 for example_path in &example_paths {
145 let example = Example::load_from_directory(example_path, &run_dir)?;
146
147 if !example
148 .base
149 .language_extension
150 .as_ref()
151 .map_or(false, |lang| languages.contains(lang))
152 {
153 skipped.push(example.name);
154 continue;
155 }
156
157 // TODO: This creates a worktree per repetition. Ideally these examples should
158 // either be run sequentially on the same worktree, or reuse worktrees when there
159 // are more examples to run than the concurrency limit.
160 for repetition_number in 0..args.repetitions {
161 let mut example = example.clone();
162 example.set_repetition_number(repetition_number);
163
164 let name_len = example.name.len();
165 if name_len > max_name_width {
166 max_name_width = example.name.len();
167 }
168
169 examples.push(example);
170 }
171 }
172
173 println!("Skipped examples: {}\n", skipped.join(", "));
174
175 if examples.is_empty() {
176 eprintln!("Filter matched no examples");
177 return cx.update(|cx| cx.quit());
178 }
179
180 let mut repo_urls = HashSet::new();
181 let mut clone_tasks = Vec::new();
182
183 for (i, example) in examples.iter_mut().enumerate() {
184 let color = COLORS[i % COLORS.len()].to_string();
185 example.set_log_prefix_style(&color, max_name_width);
186
187 println!(
188 "{}Logging to: {}",
189 example.log_prefix,
190 example.example_output_directory().display()
191 );
192
193 let repo_url = example.base.url.clone();
194 if repo_urls.insert(repo_url.clone()) {
195 let repo_path = repo_path_for_url(&repo_url);
196
197 if !repo_path.join(".git").is_dir() {
198 println!(
199 "{:<width$} < {}",
200 "↓ Cloning",
201 repo_url,
202 width = max_name_width
203 );
204
205 let git_task = cx.spawn(async move |_cx| {
206 std::fs::create_dir_all(&repo_path)?;
207 run_git(&repo_path, &["init"]).await?;
208 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
209 });
210
211 clone_tasks.push(git_task);
212 } else {
213 println!(
214 "{:<width$} < {}",
215 "✔︎ Already cloned",
216 repo_url,
217 width = max_name_width
218 );
219
220 let actual_origin =
221 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
222 if actual_origin != repo_url {
223 return Err(anyhow!(
224 "remote origin {} does not match expected origin {}",
225 actual_origin,
226 repo_url,
227 ));
228 }
229 }
230 }
231 }
232
233 future::join_all(clone_tasks).await;
234
235 for example in examples.iter_mut() {
236 example.setup().await?;
237 }
238
239 let judge_repetitions = args.judge_repetitions;
240 let concurrency = args.concurrency;
241
242 let tasks = examples.iter().map(|example| {
243 let app_state = app_state.clone();
244 let model = model.clone();
245 let example = example.clone();
246 cx.spawn(async move |cx| {
247 let result =
248 run_example(&example, model, app_state, judge_repetitions, cx).await;
249 (result, example)
250 })
251 });
252
253 let results = futures::stream::iter(tasks)
254 .buffer_unordered(concurrency)
255 .collect::<Vec<_>>()
256 .await;
257
258 println!("\n\n");
259 println!("========================================");
260 println!(" EVAL RESULTS ");
261 println!("========================================");
262 println!("");
263
264 let mut diff_scores = Vec::new();
265 let mut thread_scores = Vec::new();
266 let mut error_count = 0;
267
268 for (result, example) in results {
269 match result {
270 Err(err) => {
271 println!("💥 {}{:?}", example.log_prefix, err);
272 error_count += 1;
273 }
274 Ok(judge_results) => {
275 for judge_result in judge_results {
276 match judge_result {
277 Ok(judge_output) => {
278 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
279 let diff_score: u32 = judge_output.diff.score;
280 let score_index = (diff_score.min(5)) as usize;
281
282 println!(
283 "{} {}{} (Diff)",
284 SCORES[score_index],
285 example.log_prefix,
286 judge_output.diff.score,
287 );
288 diff_scores.push(judge_output.diff.score);
289
290 if let Some(thread) = judge_output.thread {
291 let process_score: u32 = thread.score;
292 let score_index = (process_score.min(5)) as usize;
293 println!(
294 "{} {}{} (Thread)",
295 SCORES[score_index], example.log_prefix, thread.score,
296 );
297 thread_scores.push(thread.score);
298 }
299 }
300 Err(err) => {
301 println!("💥 {}{:?}", example.log_prefix, err);
302 }
303 }
304 }
305 }
306 }
307 println!(
308 "{} > {}",
309 " ".repeat(max_name_width),
310 example.example_output_directory().display()
311 );
312 }
313
314 let diff_score_count = diff_scores.len();
315 let average_diff_score = diff_scores
316 .into_iter()
317 .map(|score| score as f32)
318 .sum::<f32>()
319 / (diff_score_count as f32);
320
321 if error_count > 0 {
322 println!("\n{error_count} examples failed to run!");
323 }
324
325 if diff_score_count > 0 {
326 println!("\nAverage code diff score: {average_diff_score}");
327 }
328
329 let thread_score_count = thread_scores.len();
330
331 // We might have gotten no thread scores if we weren't asked to judge the thread.
332 if thread_score_count > 0 {
333 let average_thread_score = thread_scores
334 .into_iter()
335 .map(|score| score as f32)
336 .sum::<f32>()
337 / (thread_score_count as f32);
338
339 if diff_score_count > 0 {
340 println!("\nAverage thread score: {average_thread_score}");
341 }
342 }
343
344 std::thread::sleep(std::time::Duration::from_secs(2));
345
346 app_state.client.telemetry().flush_events();
347
348 cx.update(|cx| cx.quit())
349 })
350 .detach_and_log_err(cx);
351 });
352}
353
354async fn run_example(
355 example: &Example,
356 model: Arc<dyn LanguageModel>,
357 app_state: Arc<AgentAppState>,
358 judge_repetitions: u32,
359 cx: &mut AsyncApp,
360) -> Result<Vec<Result<JudgeOutput>>> {
361 let run_output = cx
362 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
363 .await?;
364
365 let judge_tasks = (0..judge_repetitions)
366 .map(|round| run_judge_repetition(example.clone(), model.clone(), &run_output, round, cx));
367
368 let results = future::join_all(judge_tasks).await;
369
370 app_state.client.telemetry().flush_events();
371
372 Ok(results)
373}
374
375fn list_all_examples() -> Result<Vec<PathBuf>> {
376 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
377 let entries = std::fs::read_dir(path).unwrap();
378 let mut result_paths = Vec::new();
379 for entry in entries {
380 let entry = entry?;
381 let path = entry.path();
382 if path.is_dir() {
383 result_paths.push(path);
384 }
385 }
386 Ok(result_paths)
387}
388
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// A single instance is built by `init` and shared behind an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created with `Client::production`; also the source of the telemetry handle
    /// and the HTTP client used in this file.
    pub client: Arc<Client>,
    /// User store entity created on top of `client`.
    pub user_store: Entity<UserStore>,
    /// File system implementation; `init` installs a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options are re-sent whenever the settings store changes.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init` and passed to `agent::init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
400
401pub fn init(cx: &mut App) -> Arc<AgentAppState> {
402 release_channel::init(SemanticVersion::default(), cx);
403 gpui_tokio::init(cx);
404
405 let mut settings_store = SettingsStore::new(cx);
406 settings_store
407 .set_default_settings(settings::default_settings().as_ref(), cx)
408 .unwrap();
409 cx.set_global(settings_store);
410 client::init_settings(cx);
411
412 // Set User-Agent so we can download language servers from GitHub
413 let user_agent = format!(
414 "Zed/{} ({}; {})",
415 AppVersion::global(cx),
416 std::env::consts::OS,
417 std::env::consts::ARCH
418 );
419 let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
420 let proxy_url = proxy_str
421 .as_ref()
422 .and_then(|input| input.parse::<Uri>().ok())
423 .or_else(read_proxy_from_env);
424 let http = {
425 let _guard = Tokio::handle(cx).enter();
426
427 ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
428 .expect("could not start HTTP client")
429 };
430 cx.set_http_client(Arc::new(http));
431
432 Project::init_settings(cx);
433
434 let client = Client::production(cx);
435 cx.set_http_client(client.http_client().clone());
436
437 let git_binary_path = None;
438 let fs = Arc::new(RealFs::new(
439 git_binary_path,
440 cx.background_executor().clone(),
441 ));
442
443 let mut languages = LanguageRegistry::new(cx.background_executor().clone());
444 languages.set_language_server_download_dir(paths::languages_dir().clone());
445 let languages = Arc::new(languages);
446
447 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
448
449 extension::init(cx);
450
451 let (tx, rx) = async_watch::channel(None);
452 cx.observe_global::<SettingsStore>(move |cx| {
453 let settings = &ProjectSettings::get_global(cx).node;
454 let options = NodeBinaryOptions {
455 allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
456 allow_binary_download: true,
457 use_paths: settings.path.as_ref().map(|node_path| {
458 let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
459 let npm_path = settings
460 .npm_path
461 .as_ref()
462 .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
463 (
464 node_path.clone(),
465 npm_path.unwrap_or_else(|| {
466 let base_path = PathBuf::new();
467 node_path.parent().unwrap_or(&base_path).join("npm")
468 }),
469 )
470 }),
471 };
472 tx.send(Some(options)).log_err();
473 })
474 .detach();
475 let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
476
477 let extension_host_proxy = ExtensionHostProxy::global(cx);
478
479 language::init(cx);
480 language_extension::init(extension_host_proxy.clone(), languages.clone());
481 language_model::init(client.clone(), cx);
482 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
483 languages::init(languages.clone(), node_runtime.clone(), cx);
484 assistant_tools::init(client.http_client().clone(), cx);
485 context_server::init(cx);
486 prompt_store::init(cx);
487 let stdout_is_a_pty = false;
488 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
489 agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
490
491 SettingsStore::update_global(cx, |store, cx| {
492 store.set_user_settings(include_str!("../runner_settings.json"), cx)
493 })
494 .unwrap();
495
496 Arc::new(AgentAppState {
497 languages,
498 client,
499 user_store,
500 fs,
501 node_runtime,
502 prompt_builder,
503 })
504}
505
506pub fn find_model(
507 model_name: &str,
508 model_registry: &LanguageModelRegistry,
509 cx: &App,
510) -> anyhow::Result<Arc<dyn LanguageModel>> {
511 let model = model_registry
512 .available_models(cx)
513 .find(|model| model.id().0 == model_name);
514
515 let Some(model) = model else {
516 return Err(anyhow!(
517 "No language model named {} was available. Available models: {}",
518 model_name,
519 model_registry
520 .available_models(cx)
521 .map(|model| model.id().0.clone())
522 .collect::<Vec<_>>()
523 .join(", ")
524 ));
525 };
526
527 Ok(model)
528}
529
530pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
531 (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
532}
533
534pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
535 futures::executor::block_on(async {
536 get_current_commit_id(repo_path).await.unwrap_or_default()
537 })
538}
539
540async fn run_judge_repetition(
541 example: Example,
542 model: Arc<dyn LanguageModel>,
543 run_output: &RunOutput,
544 round: u32,
545 cx: &AsyncApp,
546) -> Result<JudgeOutput> {
547 let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
548
549 if let Ok(judge_output) = &judge_result {
550 let cohort_id = example
551 .run_directory_path
552 .file_name()
553 .map(|name| name.to_string_lossy().to_string())
554 .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
555
556 let path = std::path::Path::new(".");
557 let commit_id = get_current_commit_id(path).await.unwrap_or_default();
558
559 if let Some(thread) = &judge_output.thread {
560 telemetry::event!(
561 "Agent Eval Completed",
562 cohort_id = cohort_id,
563 example_name = example.name.clone(),
564 round = round,
565 diff_score = judge_output.diff.score,
566 diff_analysis = judge_output.diff.analysis,
567 thread_score = thread.score,
568 thread_analysis = thread.analysis,
569 tool_use_counts = run_output.tool_use_counts,
570 response_count = run_output.response_count,
571 token_usage = run_output.token_usage,
572 model = model.telemetry_id(),
573 model_provider = model.provider_id().to_string(),
574 repository_url = example.base.url.clone(),
575 repository_revision = example.base.revision.clone(),
576 diagnostics_before = run_output.diagnostics_before,
577 diagnostics_after = run_output.diagnostics_after,
578 commit_id = commit_id
579 );
580 } else {
581 telemetry::event!(
582 "Agent Eval Completed",
583 cohort_id = cohort_id,
584 example_name = example.name.clone(),
585 round = round,
586 diff_score = judge_output.diff.score,
587 diff_analysis = judge_output.diff.analysis,
588 tool_use_counts = run_output.tool_use_counts,
589 response_count = run_output.response_count,
590 token_usage = run_output.token_usage,
591 model = model.telemetry_id(),
592 model_provider = model.provider_id().to_string(),
593 repository_url = example.base.url.clone(),
594 repository_revision = example.base.revision.clone(),
595 diagnostics_before = run_output.diagnostics_before,
596 diagnostics_after = run_output.diagnostics_after,
597 commit_id = commit_id
598 );
599 }
600 }
601
602 judge_result
603}