1mod example;
2mod ids;
3
4use client::{Client, ProxySettings, UserStore};
5pub(crate) use example::*;
6use telemetry;
7
8use ::fs::RealFs;
9use anyhow::{Result, anyhow};
10use clap::Parser;
11use extension::ExtensionHostProxy;
12use futures::{StreamExt, future};
13use gpui::http_client::{Uri, read_proxy_from_env};
14use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
15use gpui_tokio::Tokio;
16use language::LanguageRegistry;
17use language_model::{
18 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
19};
20use node_runtime::{NodeBinaryOptions, NodeRuntime};
21use project::Project;
22use project::project_settings::ProjectSettings;
23use prompt_store::PromptBuilder;
24use release_channel::AppVersion;
25use reqwest_client::ReqwestClient;
26use settings::{Settings, SettingsStore};
27use std::collections::HashSet;
28use std::path::{Path, PathBuf};
29use std::sync::Arc;
30use std::usize;
31use util::ResultExt as _;
32
/// Directory (relative to the working directory) under which every eval invocation creates
/// its own timestamped run directory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
34
35#[derive(Parser, Debug)]
36#[command(name = "eval", disable_version_flag = true)]
37struct Args {
38 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
39 #[arg(value_name = "EXAMPLE_SUBSTRING")]
40 examples: Vec<String>,
41 /// Model to use (default: "claude-3-7-sonnet-latest")
42 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
43 model: String,
44 #[arg(long, value_delimiter = ',')]
45 languages: Option<Vec<String>>,
46 /// How many times to run each example. Note that this is currently not very efficient as N
47 /// worktrees will be created for the examples.
48 #[arg(long, default_value = "1")]
49 repetitions: u32,
50 /// How many times to run the judge on each example run.
51 #[arg(long, default_value = "3")]
52 judge_repetitions: u32,
53 /// Maximum number of examples to run concurrently.
54 #[arg(long, default_value = "10")]
55 concurrency: usize,
56}
57
58fn main() {
59 env_logger::init();
60
61 let args = Args::parse();
62 let all_available_examples = list_all_examples().unwrap();
63 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
64
65 let example_paths = all_available_examples
66 .iter()
67 .filter_map(|example_path| {
68 let name = example_path.file_name()?.to_string_lossy();
69 if args.examples.is_empty()
70 || args
71 .examples
72 .iter()
73 .any(|name_substring| name.contains(name_substring))
74 {
75 Some(example_path.clone())
76 } else {
77 None
78 }
79 })
80 .collect::<Vec<_>>();
81
82 let http_client = Arc::new(ReqwestClient::new());
83 let app = Application::headless().with_http_client(http_client.clone());
84
85 app.run(move |cx| {
86 let app_state = init(cx);
87
88 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
89 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
90 let session_id = uuid::Uuid::new_v4().to_string();
91
92 app_state
93 .client
94 .telemetry()
95 .start(system_id, installation_id, session_id, cx);
96
97 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
98
99 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
100 registry.set_default_model(Some(model.clone()), cx);
101 });
102
103 let model_provider_id = model.provider_id();
104
105 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
106
107 cx.spawn(async move |cx| {
108 authenticate.await.unwrap();
109
110 std::fs::create_dir_all(REPOS_DIR)?;
111 std::fs::create_dir_all(WORKTREES_DIR)?;
112
113 let run_dir = Path::new(RUNS_DIR).join(format!(
114 "{}",
115 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
116 ));
117 std::fs::create_dir_all(&run_dir)?;
118
119 let mut examples = Vec::new();
120
121 const COLORS: [&str; 12] = [
122 "\x1b[31m", // Red
123 "\x1b[32m", // Green
124 "\x1b[33m", // Yellow
125 "\x1b[34m", // Blue
126 "\x1b[35m", // Magenta
127 "\x1b[36m", // Cyan
128 "\x1b[91m", // Bright Red
129 "\x1b[92m", // Bright Green
130 "\x1b[93m", // Bright Yellow
131 "\x1b[94m", // Bright Blue
132 "\x1b[95m", // Bright Magenta
133 "\x1b[96m", // Bright Cyan
134 ];
135
136 let mut max_name_width = 0;
137 let mut skipped = Vec::new();
138
139 for example_path in &example_paths {
140 let example = Example::load_from_directory(example_path, &run_dir)?;
141
142 if !example
143 .base
144 .language_extension
145 .as_ref()
146 .map_or(false, |lang| languages.contains(lang))
147 {
148 skipped.push(example.name);
149 continue;
150 }
151
152 // TODO: This creates a worktree per repetition. Ideally these examples should
153 // either be run sequentially on the same worktree, or reuse worktrees when there
154 // are more examples to run than the concurrency limit.
155 for repetition_number in 0..args.repetitions {
156 let mut example = example.clone();
157 example.set_repetition_number(repetition_number);
158
159 let name_len = example.name.len();
160 if name_len > max_name_width {
161 max_name_width = example.name.len();
162 }
163
164 examples.push(example);
165 }
166 }
167
168 println!("Skipped examples: {}\n", skipped.join(", "));
169
170 if examples.is_empty() {
171 eprintln!("Filter matched no examples");
172 return cx.update(|cx| cx.quit());
173 }
174
175 let mut repo_urls = HashSet::new();
176 let mut clone_tasks = Vec::new();
177
178 for (i, example) in examples.iter_mut().enumerate() {
179 let color = COLORS[i % COLORS.len()].to_string();
180 example.set_log_prefix_style(&color, max_name_width);
181
182 println!(
183 "{}Logging to: {}",
184 example.log_prefix,
185 example.example_output_directory().display()
186 );
187
188 let repo_url = example.base.url.clone();
189 if repo_urls.insert(repo_url.clone()) {
190 let repo_path = repo_path_for_url(&repo_url);
191
192 if !repo_path.join(".git").is_dir() {
193 println!(
194 "{:<width$} < {}",
195 "↓ Cloning",
196 repo_url,
197 width = max_name_width
198 );
199
200 let git_task = cx.spawn(async move |_cx| {
201 std::fs::create_dir_all(&repo_path)?;
202 run_git(&repo_path, &["init"]).await?;
203 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
204 });
205
206 clone_tasks.push(git_task);
207 } else {
208 println!(
209 "{:<width$} < {}",
210 "✔︎ Already cloned",
211 repo_url,
212 width = max_name_width
213 );
214
215 let actual_origin =
216 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
217 if actual_origin != repo_url {
218 return Err(anyhow!(
219 "remote origin {} does not match expected origin {}",
220 actual_origin,
221 repo_url,
222 ));
223 }
224 }
225 }
226 }
227
228 future::join_all(clone_tasks).await;
229
230 for example in examples.iter_mut() {
231 example.setup().await?;
232 }
233
234 let judge_repetitions = args.judge_repetitions;
235 let concurrency = args.concurrency;
236
237 let tasks = examples.iter().map(|example| {
238 let app_state = app_state.clone();
239 let model = model.clone();
240 let example = example.clone();
241 cx.spawn(async move |cx| {
242 let result =
243 run_example(&example, model, app_state, judge_repetitions, cx).await;
244 (result, example)
245 })
246 });
247
248 let results = futures::stream::iter(tasks)
249 .buffer_unordered(concurrency)
250 .collect::<Vec<_>>()
251 .await;
252
253 println!("\n\n");
254 println!("========================================");
255 println!(" EVAL RESULTS ");
256 println!("========================================");
257 println!("");
258
259 let mut diff_scores = Vec::new();
260 let mut thread_scores = Vec::new();
261 let mut error_count = 0;
262
263 for (result, example) in results {
264 match result {
265 Err(err) => {
266 println!("💥 {}{:?}", example.log_prefix, err);
267 error_count += 1;
268 }
269 Ok(judge_results) => {
270 for judge_result in judge_results {
271 match judge_result {
272 Ok(judge_output) => {
273 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
274 let diff_score: u32 = judge_output.diff.score;
275 let score_index = (diff_score.min(5)) as usize;
276
277 println!(
278 "{} {}{} (Diff)",
279 SCORES[score_index],
280 example.log_prefix,
281 judge_output.diff.score,
282 );
283 diff_scores.push(judge_output.diff.score);
284
285 if let Some(thread) = judge_output.thread {
286 let process_score: u32 = thread.score;
287 let score_index = (process_score.min(5)) as usize;
288 println!(
289 "{} {}{} (Thread)",
290 SCORES[score_index], example.log_prefix, thread.score,
291 );
292 thread_scores.push(thread.score);
293 }
294 }
295 Err(err) => {
296 println!("💥 {}{:?}", example.log_prefix, err);
297 }
298 }
299 }
300 }
301 }
302 println!(
303 "{} > {}",
304 " ".repeat(max_name_width),
305 example.example_output_directory().display()
306 );
307 }
308
309 let diff_score_count = diff_scores.len();
310 let average_diff_score = diff_scores
311 .into_iter()
312 .map(|score| score as f32)
313 .sum::<f32>()
314 / (diff_score_count as f32);
315
316 if error_count > 0 {
317 println!("\n{error_count} examples failed to run!");
318 }
319
320 if diff_score_count > 0 {
321 println!("\nAverage code diff score: {average_diff_score}");
322 }
323
324 let thread_score_count = thread_scores.len();
325
326 // We might have gotten no thread scores if we weren't asked to judge the thread.
327 if thread_score_count > 0 {
328 let average_thread_score = thread_scores
329 .into_iter()
330 .map(|score| score as f32)
331 .sum::<f32>()
332 / (thread_score_count as f32);
333
334 if diff_score_count > 0 {
335 println!("\nAverage thread score: {average_thread_score}");
336 }
337 }
338
339 std::thread::sleep(std::time::Duration::from_secs(2));
340
341 app_state.client.telemetry().flush_events();
342
343 cx.update(|cx| cx.quit())
344 })
345 .detach_and_log_err(cx);
346 });
347}
348
349async fn run_example(
350 example: &Example,
351 model: Arc<dyn LanguageModel>,
352 app_state: Arc<AgentAppState>,
353 judge_repetitions: u32,
354 cx: &mut AsyncApp,
355) -> Result<Vec<Result<JudgeOutput>>> {
356 let run_output = cx
357 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
358 .await?;
359
360 let judge_tasks = (0..judge_repetitions)
361 .map(|round| run_judge_repetition(example.clone(), model.clone(), &run_output, round, cx));
362
363 let results = future::join_all(judge_tasks).await;
364
365 app_state.client.telemetry().flush_events();
366
367 Ok(results)
368}
369
370fn list_all_examples() -> Result<Vec<PathBuf>> {
371 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
372 let entries = std::fs::read_dir(path).unwrap();
373 let mut result_paths = Vec::new();
374 for entry in entries {
375 let entry = entry?;
376 let path = entry.path();
377 if path.is_dir() {
378 result_paths.push(path);
379 }
380 }
381 Ok(result_paths)
382}
383
384/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
385pub struct AgentAppState {
386 pub languages: Arc<LanguageRegistry>,
387 pub client: Arc<Client>,
388 pub user_store: Entity<UserStore>,
389 pub fs: Arc<dyn fs::Fs>,
390 pub node_runtime: NodeRuntime,
391
392 // Additional fields not present in `workspace::AppState`.
393 pub prompt_builder: Arc<PromptBuilder>,
394}
395
396pub fn init(cx: &mut App) -> Arc<AgentAppState> {
397 release_channel::init(SemanticVersion::default(), cx);
398 gpui_tokio::init(cx);
399
400 let mut settings_store = SettingsStore::new(cx);
401 settings_store
402 .set_default_settings(settings::default_settings().as_ref(), cx)
403 .unwrap();
404 cx.set_global(settings_store);
405 client::init_settings(cx);
406
407 // Set User-Agent so we can download language servers from GitHub
408 let user_agent = format!(
409 "Zed/{} ({}; {})",
410 AppVersion::global(cx),
411 std::env::consts::OS,
412 std::env::consts::ARCH
413 );
414 let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
415 let proxy_url = proxy_str
416 .as_ref()
417 .and_then(|input| input.parse::<Uri>().ok())
418 .or_else(read_proxy_from_env);
419 let http = {
420 let _guard = Tokio::handle(cx).enter();
421
422 ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
423 .expect("could not start HTTP client")
424 };
425 cx.set_http_client(Arc::new(http));
426
427 Project::init_settings(cx);
428
429 let client = Client::production(cx);
430 cx.set_http_client(client.http_client().clone());
431
432 let git_binary_path = None;
433 let fs = Arc::new(RealFs::new(
434 git_binary_path,
435 cx.background_executor().clone(),
436 ));
437
438 let mut languages = LanguageRegistry::new(cx.background_executor().clone());
439 languages.set_language_server_download_dir(paths::languages_dir().clone());
440 let languages = Arc::new(languages);
441
442 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
443
444 extension::init(cx);
445
446 let (tx, rx) = async_watch::channel(None);
447 cx.observe_global::<SettingsStore>(move |cx| {
448 let settings = &ProjectSettings::get_global(cx).node;
449 let options = NodeBinaryOptions {
450 allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
451 allow_binary_download: true,
452 use_paths: settings.path.as_ref().map(|node_path| {
453 let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
454 let npm_path = settings
455 .npm_path
456 .as_ref()
457 .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
458 (
459 node_path.clone(),
460 npm_path.unwrap_or_else(|| {
461 let base_path = PathBuf::new();
462 node_path.parent().unwrap_or(&base_path).join("npm")
463 }),
464 )
465 }),
466 };
467 tx.send(Some(options)).log_err();
468 })
469 .detach();
470 let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
471
472 let extension_host_proxy = ExtensionHostProxy::global(cx);
473
474 language::init(cx);
475 language_extension::init(extension_host_proxy.clone(), languages.clone());
476 language_model::init(client.clone(), cx);
477 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
478 languages::init(languages.clone(), node_runtime.clone(), cx);
479 assistant_tools::init(client.http_client().clone(), cx);
480 context_server::init(cx);
481 prompt_store::init(cx);
482 let stdout_is_a_pty = false;
483 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
484 agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
485
486 SettingsStore::update_global(cx, |store, cx| {
487 store.set_user_settings(include_str!("../runner_settings.json"), cx)
488 })
489 .unwrap();
490
491 Arc::new(AgentAppState {
492 languages,
493 client,
494 user_store,
495 fs,
496 node_runtime,
497 prompt_builder,
498 })
499}
500
501pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
502 let model_registry = LanguageModelRegistry::read_global(cx);
503 let model = model_registry
504 .available_models(cx)
505 .find(|model| model.id().0 == model_name);
506
507 let Some(model) = model else {
508 return Err(anyhow!(
509 "No language model named {} was available. Available models: {}",
510 model_name,
511 model_registry
512 .available_models(cx)
513 .map(|model| model.id().0.clone())
514 .collect::<Vec<_>>()
515 .join(", ")
516 ));
517 };
518
519 Ok(model)
520}
521
522pub fn authenticate_model_provider(
523 provider_id: LanguageModelProviderId,
524 cx: &mut App,
525) -> Task<std::result::Result<(), AuthenticateError>> {
526 let model_registry = LanguageModelRegistry::read_global(cx);
527 let model_provider = model_registry.provider(&provider_id).unwrap();
528 model_provider.authenticate(cx)
529}
530
531pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
532 (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
533}
534
535pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
536 futures::executor::block_on(async {
537 get_current_commit_id(repo_path).await.unwrap_or_default()
538 })
539}
540
541async fn run_judge_repetition(
542 example: Example,
543 model: Arc<dyn LanguageModel>,
544 run_output: &RunOutput,
545 round: u32,
546 cx: &AsyncApp,
547) -> Result<JudgeOutput> {
548 let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
549
550 if let Ok(judge_output) = &judge_result {
551 let cohort_id = example
552 .run_directory_path
553 .file_name()
554 .map(|name| name.to_string_lossy().to_string())
555 .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
556
557 let path = std::path::Path::new(".");
558 let commit_id = get_current_commit_id(path).await.unwrap_or_default();
559
560 if let Some(thread) = &judge_output.thread {
561 telemetry::event!(
562 "Agent Eval Completed",
563 cohort_id = cohort_id,
564 example_name = example.name.clone(),
565 round = round,
566 diff_score = judge_output.diff.score,
567 diff_analysis = judge_output.diff.analysis,
568 thread_score = thread.score,
569 thread_analysis = thread.analysis,
570 tool_use_counts = run_output.tool_use_counts,
571 response_count = run_output.response_count,
572 token_usage = run_output.token_usage,
573 model = model.telemetry_id(),
574 model_provider = model.provider_id().to_string(),
575 repository_url = example.base.url.clone(),
576 repository_revision = example.base.revision.clone(),
577 diagnostics_before = run_output.diagnostics_before,
578 diagnostics_after = run_output.diagnostics_after,
579 commit_id = commit_id
580 );
581 } else {
582 telemetry::event!(
583 "Agent Eval Completed",
584 cohort_id = cohort_id,
585 example_name = example.name.clone(),
586 round = round,
587 diff_score = judge_output.diff.score,
588 diff_analysis = judge_output.diff.analysis,
589 tool_use_counts = run_output.tool_use_counts,
590 response_count = run_output.response_count,
591 token_usage = run_output.token_usage,
592 model = model.telemetry_id(),
593 model_provider = model.provider_id().to_string(),
594 repository_url = example.base.url.clone(),
595 repository_revision = example.base.revision.clone(),
596 diagnostics_before = run_output.diagnostics_before,
597 diagnostics_after = run_output.diagnostics_after,
598 commit_id = commit_id
599 );
600 }
601 }
602
603 judge_result
604}