1mod example;
2mod ids;
3
4use client::{Client, ProxySettings, UserStore};
5pub(crate) use example::*;
6use telemetry;
7
8use ::fs::RealFs;
9use anyhow::{Result, anyhow};
10use clap::Parser;
11use extension::ExtensionHostProxy;
12use futures::future;
13use futures::stream::StreamExt;
14use gpui::http_client::{Uri, read_proxy_from_env};
15use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
16use gpui_tokio::Tokio;
17use language::LanguageRegistry;
18use language_model::{
19 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
20};
21use node_runtime::{NodeBinaryOptions, NodeRuntime};
22use project::Project;
23use project::project_settings::ProjectSettings;
24use prompt_store::PromptBuilder;
25use release_channel::AppVersion;
26use reqwest_client::ReqwestClient;
27use settings::{Settings, SettingsStore};
28use std::collections::HashSet;
29use std::path::{Path, PathBuf};
30use std::sync::Arc;
31use std::usize;
32use util::ResultExt as _;
33
/// Directory (relative to the current working directory) under which each eval run creates
/// its own output subdirectory, named after the run's local start time.
pub const RUNS_DIR: &str = "./crates/eval/runs";
35
36#[derive(Parser, Debug)]
37#[command(name = "eval", disable_version_flag = true)]
38struct Args {
39 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
40 #[arg(value_name = "EXAMPLE_SUBSTRING")]
41 examples: Vec<String>,
42 /// Model to use (default: "claude-3-7-sonnet-latest")
43 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
44 model: String,
45 #[arg(long, value_delimiter = ',')]
46 languages: Option<Vec<String>>,
47 /// How many times to run each example. Note that this is currently not very efficient as N
48 /// worktrees will be created for the examples.
49 #[arg(long, default_value = "1")]
50 repetitions: u32,
51 /// How many times to run the judge on each example run.
52 #[arg(long, default_value = "3")]
53 judge_repetitions: u32,
54 /// Maximum number of examples to run concurrently.
55 #[arg(long, default_value = "10")]
56 concurrency: usize,
57}
58
59fn main() {
60 env_logger::init();
61
62 let args = Args::parse();
63 let all_available_examples = list_all_examples().unwrap();
64 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
65
66 let example_paths = all_available_examples
67 .iter()
68 .filter_map(|example_path| {
69 let name = example_path.file_name()?.to_string_lossy();
70 if args.examples.is_empty()
71 || args
72 .examples
73 .iter()
74 .any(|name_substring| name.contains(name_substring))
75 {
76 Some(example_path.clone())
77 } else {
78 None
79 }
80 })
81 .collect::<Vec<_>>();
82
83 let http_client = Arc::new(ReqwestClient::new());
84 let app = Application::headless().with_http_client(http_client.clone());
85
86 app.run(move |cx| {
87 let app_state = init(cx);
88
89 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
90 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
91 let session_id = uuid::Uuid::new_v4().to_string();
92
93 app_state
94 .client
95 .telemetry()
96 .start(system_id, installation_id, session_id, cx);
97
98 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
99
100 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
101 registry.set_default_model(Some(model.clone()), cx);
102 });
103
104 let model_provider_id = model.provider_id();
105
106 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
107
108 cx.spawn(async move |cx| {
109 authenticate.await.unwrap();
110
111 std::fs::create_dir_all(REPOS_DIR)?;
112 std::fs::create_dir_all(WORKTREES_DIR)?;
113
114 let run_dir = Path::new(RUNS_DIR).join(format!(
115 "{}",
116 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
117 ));
118 std::fs::create_dir_all(&run_dir)?;
119
120 let mut examples = Vec::new();
121
122 const COLORS: [&str; 12] = [
123 "\x1b[31m", // Red
124 "\x1b[32m", // Green
125 "\x1b[33m", // Yellow
126 "\x1b[34m", // Blue
127 "\x1b[35m", // Magenta
128 "\x1b[36m", // Cyan
129 "\x1b[91m", // Bright Red
130 "\x1b[92m", // Bright Green
131 "\x1b[93m", // Bright Yellow
132 "\x1b[94m", // Bright Blue
133 "\x1b[95m", // Bright Magenta
134 "\x1b[96m", // Bright Cyan
135 ];
136
137 let mut max_name_width = 0;
138 let mut skipped = Vec::new();
139
140 for example_path in &example_paths {
141 let example = Example::load_from_directory(example_path, &run_dir)?;
142
143 if !example
144 .base
145 .language_extension
146 .as_ref()
147 .map_or(false, |lang| languages.contains(lang))
148 {
149 skipped.push(example.name);
150 continue;
151 }
152
153 // TODO: This creates a worktree per repetition. Ideally these examples should
154 // either be run sequentially on the same worktree, or reuse worktrees when there
155 // are more examples to run than the concurrency limit.
156 for repetition_number in 0..args.repetitions {
157 let mut example = example.clone();
158 example.set_repetition_number(repetition_number);
159
160 let name_len = example.name.len();
161 if name_len > max_name_width {
162 max_name_width = example.name.len();
163 }
164
165 examples.push(example);
166 }
167 }
168
169 println!("Skipped examples: {}\n", skipped.join(", "));
170
171 if examples.is_empty() {
172 eprintln!("Filter matched no examples");
173 return cx.update(|cx| cx.quit());
174 }
175
176 let mut repo_urls = HashSet::new();
177 let mut clone_tasks = Vec::new();
178
179 for (i, example) in examples.iter_mut().enumerate() {
180 let color = COLORS[i % COLORS.len()].to_string();
181 example.set_log_prefix_style(&color, max_name_width);
182
183 println!(
184 "{}Logging to: {}",
185 example.log_prefix,
186 example.output_file_path.display()
187 );
188
189 let repo_url = example.base.url.clone();
190 if repo_urls.insert(repo_url.clone()) {
191 let repo_path = repo_path_for_url(&repo_url);
192
193 if !repo_path.join(".git").is_dir() {
194 println!(
195 "{:<width$} < {}",
196 "↓ Cloning",
197 repo_url,
198 width = max_name_width
199 );
200
201 let git_task = cx.spawn(async move |_cx| {
202 std::fs::create_dir_all(&repo_path)?;
203 run_git(&repo_path, &["init"]).await?;
204 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
205 });
206
207 clone_tasks.push(git_task);
208 } else {
209 println!(
210 "{:<width$} < {}",
211 "✔︎ Already cloned",
212 repo_url,
213 width = max_name_width
214 );
215
216 let actual_origin =
217 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
218 if actual_origin != repo_url {
219 return Err(anyhow!(
220 "remote origin {} does not match expected origin {}",
221 actual_origin,
222 repo_url,
223 ));
224 }
225 }
226 }
227 }
228
229 future::join_all(clone_tasks).await;
230
231 for example in examples.iter_mut() {
232 example.setup().await?;
233 }
234
235 let judge_repetitions = args.judge_repetitions;
236 let concurrency = args.concurrency;
237
238 let tasks = examples
239 .into_iter()
240 .map(|example| {
241 let app_state = app_state.clone();
242 let model = model.clone();
243 cx.spawn(async move |cx| {
244 let result =
245 run_example(&example, model, app_state, judge_repetitions, cx).await;
246 (result, example)
247 })
248 })
249 .collect::<Vec<_>>();
250
251 let results = futures::stream::iter(tasks)
252 .buffer_unordered(concurrency)
253 .collect::<Vec<(Result<Vec<Result<JudgeOutput>>>, Example)>>()
254 .await;
255
256 println!("\n\n");
257 println!("========================================");
258 println!(" EVAL RESULTS ");
259 println!("========================================");
260 println!("");
261
262 let mut judge_scores = Vec::new();
263
264 for (result, example) in results {
265 match result {
266 Err(err) => {
267 println!("💥 {}{:?}", example.log_prefix, err);
268 }
269 Ok(judge_results) => {
270 for judge_result in judge_results {
271 match judge_result {
272 Ok(judge_output) => {
273 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
274 let score: u32 = judge_output.score;
275 let score_index = (score.min(5)) as usize;
276
277 println!(
278 "{} {}{}",
279 SCORES[score_index], example.log_prefix, judge_output.score,
280 );
281 judge_scores.push(judge_output.score);
282 }
283 Err(err) => {
284 println!("💥 {}{:?}", example.log_prefix, err);
285 }
286 }
287 }
288 }
289 }
290 println!(
291 "{} > {}",
292 " ".repeat(max_name_width),
293 example.output_file_path.display()
294 );
295 }
296
297 let score_count = judge_scores.len();
298 let average_score = judge_scores
299 .into_iter()
300 .map(|score| score as f32)
301 .sum::<f32>()
302 / (score_count as f32);
303 println!("\nAverage score: {average_score}");
304
305 std::thread::sleep(std::time::Duration::from_secs(2));
306
307 app_state.client.telemetry().flush_events();
308
309 cx.update(|cx| cx.quit())
310 })
311 .detach_and_log_err(cx);
312 });
313}
314
315async fn run_example(
316 example: &Example,
317 model: Arc<dyn LanguageModel>,
318 app_state: Arc<AgentAppState>,
319 judge_repetitions: u32,
320 cx: &mut AsyncApp,
321) -> Result<Vec<Result<JudgeOutput>>> {
322 let run_output = cx
323 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
324 .await?;
325 let diff = example.repository_diff().await?;
326
327 // Run judge for each repetition
328 let mut results = Vec::new();
329 for round in 0..judge_repetitions {
330 let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
331
332 if let Ok(judge_output) = &judge_result {
333 let cohort_id = example
334 .output_file_path
335 .parent()
336 .and_then(|p| p.file_name())
337 .map(|name| name.to_string_lossy().to_string())
338 .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
339
340 let path = std::path::Path::new(".");
341 let commit_id = get_current_commit_id(path).await.unwrap_or_default();
342
343 telemetry::event!(
344 "Agent Eval Completed",
345 cohort_id = cohort_id,
346 example_name = example.name.clone(),
347 round = round,
348 score = judge_output.score,
349 analysis = judge_output.analysis,
350 tool_use_counts = run_output.tool_use_counts,
351 response_count = run_output.response_count,
352 token_usage = run_output.token_usage,
353 model = model.telemetry_id(),
354 model_provider = model.provider_id().to_string(),
355 repository_url = example.base.url.clone(),
356 repository_revision = example.base.revision.clone(),
357 diagnostics_summary = run_output.diagnostics,
358 commit_id = commit_id
359 );
360 }
361
362 results.push(judge_result);
363 }
364
365 app_state.client.telemetry().flush_events();
366
367 Ok(results)
368}
369
370fn list_all_examples() -> Result<Vec<PathBuf>> {
371 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
372 let entries = std::fs::read_dir(path).unwrap();
373 let mut result_paths = Vec::new();
374 for entry in entries {
375 let entry = entry?;
376 let path = entry.path();
377 if path.is_dir() {
378 result_paths.push(path);
379 }
380 }
381 Ok(result_paths)
382}
383
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by [`init`] and shared behind an `Arc` with every example run.
pub struct AgentAppState {
    /// Language registry whose language-server download directory is configured in [`init`].
    pub languages: Arc<LanguageRegistry>,
    /// Production client; the eval runner also starts and flushes telemetry through it.
    pub client: Arc<Client>,
    /// User store entity created from `client` in [`init`].
    pub user_store: Entity<UserStore>,
    /// Filesystem handle (a `RealFs` in [`init`]).
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options are derived from the `node` project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded via `PromptBuilder::load` in [`init`].
    pub prompt_builder: Arc<PromptBuilder>,
}
395
/// Initializes the global state the headless eval app needs and returns it bundled as an
/// [`AgentAppState`].
///
/// In order, it sets up:
/// - the release channel and the Tokio integration;
/// - the global settings store (from the default settings);
/// - an HTTP client that uses the configured proxy, falling back to the environment;
/// - the production [`Client`], the filesystem, the language registry and the user store;
/// - a Node runtime driven by the `node` project settings;
/// - the language, model, tool, context-server and agent subsystems.
///
/// Finally it applies `runner_settings.json` as the user settings.
///
/// # Panics
///
/// Panics if the default or runner settings cannot be applied, or if the HTTP client
/// cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // NOTE(review): the app version registered here is `SemanticVersion::default()` and it is
    // what `AppVersion::global` reports in the User-Agent below — confirm that is intended.
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Proxy from settings if it parses as a URI, otherwise whatever the environment specifies.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Enter the Tokio runtime while building the client; presumably reqwest needs a
        // runtime context at construction time (TODO confirm).
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // NOTE(review): presumably `Client::production` picks up the proxy-aware HTTP client
    // installed above; the app-wide client is then replaced with the one the `Client` owns.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Watch channel feeding Node binary options to the Node runtime. It starts as `None`; the
    // observer below sends fresh options whenever the global `SettingsStore` is observed to
    // change (presumably including the user-settings update at the end of this function —
    // confirm).
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            // Only set when the user configured an explicit node path; `~` is expanded.
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary (or a
                    // bare relative `npm` if the node path has no parent).
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    // The eval runner is headless, so prompts are built as if stdout is not a terminal.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Layer the runner's own settings on top of the defaults, after every subsystem has
    // registered its settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
499
500pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
501 let model_registry = LanguageModelRegistry::read_global(cx);
502 let model = model_registry
503 .available_models(cx)
504 .find(|model| model.id().0 == model_name);
505
506 let Some(model) = model else {
507 return Err(anyhow!(
508 "No language model named {} was available. Available models: {}",
509 model_name,
510 model_registry
511 .available_models(cx)
512 .map(|model| model.id().0.clone())
513 .collect::<Vec<_>>()
514 .join(", ")
515 ));
516 };
517
518 Ok(model)
519}
520
521pub fn authenticate_model_provider(
522 provider_id: LanguageModelProviderId,
523 cx: &mut App,
524) -> Task<std::result::Result<(), AuthenticateError>> {
525 let model_registry = LanguageModelRegistry::read_global(cx);
526 let model_provider = model_registry.provider(&provider_id).unwrap();
527 model_provider.authenticate(cx)
528}
529
530pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
531 (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
532}
533
534pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
535 futures::executor::block_on(async {
536 get_current_commit_id(repo_path).await.unwrap_or_default()
537 })
538}