1mod example;
2mod ids;
3
4use client::{Client, ProxySettings, UserStore};
5pub(crate) use example::*;
6use telemetry;
7
8use ::fs::RealFs;
9use anyhow::{Result, anyhow};
10use clap::Parser;
11use extension::ExtensionHostProxy;
12use futures::future;
13use futures::stream::StreamExt;
14use gpui::http_client::{Uri, read_proxy_from_env};
15use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
16use gpui_tokio::Tokio;
17use language::LanguageRegistry;
18use language_model::{
19 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
20};
21use node_runtime::{NodeBinaryOptions, NodeRuntime};
22use project::Project;
23use project::project_settings::ProjectSettings;
24use prompt_store::PromptBuilder;
25use release_channel::AppVersion;
26use reqwest_client::ReqwestClient;
27use settings::{Settings, SettingsStore};
28use std::collections::HashSet;
29use std::path::{Path, PathBuf};
30use std::sync::Arc;
31use std::usize;
32use util::ResultExt as _;
33
/// Root directory (relative to the current working directory) under which each eval run
/// creates its own timestamped subdirectory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
35
36#[derive(Parser, Debug)]
37#[command(name = "eval", disable_version_flag = true)]
38struct Args {
39 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
40 #[arg(value_name = "EXAMPLE_SUBSTRING")]
41 examples: Vec<String>,
42 /// Model to use (default: "claude-3-7-sonnet-latest")
43 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
44 model: String,
45 #[arg(long, value_delimiter = ',')]
46 languages: Option<Vec<String>>,
47 /// How many times to run the judge on each example run.
48 #[arg(long, default_value = "3")]
49 judge_repetitions: u32,
50 /// Maximum number of examples to run concurrently.
51 #[arg(long, default_value = "10")]
52 concurrency: usize,
53}
54
55fn main() {
56 env_logger::init();
57
58 let args = Args::parse();
59 let all_available_examples = list_all_examples().unwrap();
60 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
61
62 let example_paths = all_available_examples
63 .iter()
64 .filter_map(|example_path| {
65 let name = example_path.file_name()?.to_string_lossy();
66 if args.examples.is_empty()
67 || args
68 .examples
69 .iter()
70 .any(|name_substring| name.contains(name_substring))
71 {
72 Some(example_path.clone())
73 } else {
74 None
75 }
76 })
77 .collect::<Vec<_>>();
78
79 let http_client = Arc::new(ReqwestClient::new());
80 let app = Application::headless().with_http_client(http_client.clone());
81
82 app.run(move |cx| {
83 let app_state = init(cx);
84
85 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
86 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
87 let session_id = uuid::Uuid::new_v4().to_string();
88
89 app_state
90 .client
91 .telemetry()
92 .start(system_id, installation_id, session_id, cx);
93
94 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
95
96 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
97 registry.set_default_model(Some(model.clone()), cx);
98 });
99
100 let model_provider_id = model.provider_id();
101
102 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
103
104 cx.spawn(async move |cx| {
105 authenticate.await.unwrap();
106
107 std::fs::create_dir_all(REPOS_DIR)?;
108 std::fs::create_dir_all(WORKTREES_DIR)?;
109
110 let run_dir = Path::new(RUNS_DIR).join(format!(
111 "{}",
112 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
113 ));
114 std::fs::create_dir_all(&run_dir)?;
115
116 let mut examples = Vec::new();
117
118 const COLORS: [&str; 12] = [
119 "\x1b[31m", // Red
120 "\x1b[32m", // Green
121 "\x1b[33m", // Yellow
122 "\x1b[34m", // Blue
123 "\x1b[35m", // Magenta
124 "\x1b[36m", // Cyan
125 "\x1b[91m", // Bright Red
126 "\x1b[92m", // Bright Green
127 "\x1b[93m", // Bright Yellow
128 "\x1b[94m", // Bright Blue
129 "\x1b[95m", // Bright Magenta
130 "\x1b[96m", // Bright Cyan
131 ];
132
133 let mut max_name_width = 0;
134 let mut skipped = Vec::new();
135
136 for example_path in &example_paths {
137 let example = Example::load_from_directory(example_path, &run_dir)?;
138
139 if !example
140 .base
141 .language_extension
142 .as_ref()
143 .map_or(false, |lang| languages.contains(lang))
144 {
145 skipped.push(example.name);
146 continue;
147 }
148
149 let name_len = example.name.len();
150 if name_len > max_name_width {
151 max_name_width = example.name.len();
152 }
153
154 examples.push(example);
155 }
156
157 println!("Skipped examples: {}\n", skipped.join(", "));
158
159 if examples.is_empty() {
160 eprintln!("Filter matched no examples");
161 return cx.update(|cx| cx.quit());
162 }
163
164 let mut repo_urls = HashSet::new();
165 let mut clone_tasks = Vec::new();
166
167 for (i, example) in examples.iter_mut().enumerate() {
168 let color = COLORS[i % COLORS.len()].to_string();
169 example.set_log_prefix_style(&color, max_name_width);
170
171 println!(
172 "{}Logging to: {}",
173 example.log_prefix,
174 example.output_file_path.display()
175 );
176
177 let repo_url = example.base.url.clone();
178 if repo_urls.insert(repo_url.clone()) {
179 let repo_path = repo_path_for_url(&repo_url);
180
181 if !repo_path.join(".git").is_dir() {
182 println!(
183 "{:<width$} < {}",
184 "↓ Cloning",
185 repo_url,
186 width = max_name_width
187 );
188
189 let git_task = cx.spawn(async move |_cx| {
190 std::fs::create_dir_all(&repo_path)?;
191 run_git(&repo_path, &["init"]).await?;
192 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
193 });
194
195 clone_tasks.push(git_task);
196 } else {
197 println!(
198 "{:<width$} < {}",
199 "✔︎ Already cloned",
200 repo_url,
201 width = max_name_width
202 );
203
204 let actual_origin =
205 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
206 if actual_origin != repo_url {
207 return Err(anyhow!(
208 "remote origin {} does not match expected origin {}",
209 actual_origin,
210 repo_url,
211 ));
212 }
213 }
214 }
215 }
216
217 future::join_all(clone_tasks).await;
218
219 for example in examples.iter_mut() {
220 example.setup().await?;
221 }
222
223 let judge_repetitions = args.judge_repetitions;
224 let concurrency = args.concurrency;
225
226 let tasks = examples
227 .into_iter()
228 .map(|example| {
229 let app_state = app_state.clone();
230 let model = model.clone();
231 cx.spawn(async move |cx| {
232 let result =
233 run_example(&example, model, app_state, judge_repetitions, cx).await;
234 (result, example)
235 })
236 })
237 .collect::<Vec<_>>();
238
239 let results = futures::stream::iter(tasks)
240 .buffer_unordered(concurrency)
241 .collect::<Vec<(Result<Vec<Result<JudgeOutput>>>, Example)>>()
242 .await;
243
244 println!("\n\n");
245 println!("========================================");
246 println!(" EVAL RESULTS ");
247 println!("========================================");
248 println!("");
249
250 let mut judge_scores = Vec::new();
251
252 for (result, example) in results {
253 match result {
254 Err(err) => {
255 println!("💥 {}{:?}", example.log_prefix, err);
256 }
257 Ok(judge_results) => {
258 for judge_result in judge_results {
259 match judge_result {
260 Ok(judge_output) => {
261 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
262
263 println!(
264 "{} {}{}",
265 SCORES[judge_output.score.min(5) as usize],
266 example.log_prefix,
267 judge_output.score,
268 );
269 judge_scores.push(judge_output.score);
270 }
271 Err(err) => {
272 println!("💥 {}{:?}", example.log_prefix, err);
273 }
274 }
275 }
276 }
277 }
278 println!(
279 "{} > {}",
280 " ".repeat(max_name_width),
281 example.output_file_path.display()
282 );
283 }
284
285 let score_count = judge_scores.len();
286 let average_score = judge_scores
287 .into_iter()
288 .map(|score| score as f32)
289 .sum::<f32>()
290 / (score_count as f32);
291 println!("\nAverage score: {average_score}");
292
293 std::thread::sleep(std::time::Duration::from_secs(2));
294
295 // Flush telemetry events before exiting
296 app_state.client.telemetry().flush_events();
297
298 cx.update(|cx| cx.quit())
299 })
300 .detach_and_log_err(cx);
301 });
302}
303
304async fn run_example(
305 example: &Example,
306 model: Arc<dyn LanguageModel>,
307 app_state: Arc<AgentAppState>,
308 judge_repetitions: u32,
309 cx: &mut AsyncApp,
310) -> Result<Vec<Result<JudgeOutput>>> {
311 let run_output = cx
312 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
313 .await?;
314 let diff = example.repository_diff().await?;
315
316 // Run judge for each repetition
317 let mut results = Vec::new();
318 for round in 0..judge_repetitions {
319 let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
320
321 // Log telemetry for this judge result
322 if let Ok(judge_output) = &judge_result {
323 let cohort_id = example
324 .output_file_path
325 .parent()
326 .and_then(|p| p.file_name())
327 .map(|name| name.to_string_lossy().to_string())
328 .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
329
330 telemetry::event!(
331 "Agent Eval Completed",
332 cohort_id = cohort_id,
333 example_name = example.name.clone(),
334 round = round,
335 score = judge_output.score,
336 analysis = judge_output.analysis,
337 tool_use_counts = run_output.tool_use_counts,
338 response_count = run_output.response_count,
339 token_usage = run_output.token_usage,
340 model = model.telemetry_id(),
341 model_provider = model.provider_id().to_string(),
342 repository_url = example.base.url.clone(),
343 repository_revision = example.base.revision.clone(),
344 diagnostics_summary = run_output.diagnostics
345 );
346 }
347
348 results.push(judge_result);
349 }
350
351 app_state.client.telemetry().flush_events();
352
353 Ok(results)
354}
355
356fn list_all_examples() -> Result<Vec<PathBuf>> {
357 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
358 let entries = std::fs::read_dir(path).unwrap();
359 let mut result_paths = Vec::new();
360 for entry in entries {
361 let entry = entry?;
362 let path = entry.path();
363 if path.is_dir() {
364 result_paths.push(path);
365 }
366 }
367 Ok(result_paths)
368}
369
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by `init` and shared behind an `Arc`; `main` clones it for each example run.
pub struct AgentAppState {
    /// Language registry created in `init`, with its language server download dir configured.
    pub languages: Arc<LanguageRegistry>,
    /// Client providing the HTTP client and telemetry used by the eval runner.
    pub client: Arc<Client>,
    /// User store constructed from `client` in `init`.
    pub user_store: Entity<UserStore>,
    /// Filesystem handle; `init` supplies a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options are derived from the project settings in `init`.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded in `init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
381
/// Initializes the headless app for running evals and returns the shared `AgentAppState`.
///
/// This installs the following globals and subsystems:
/// - settings: the built-in defaults plus `runner_settings.json` as user settings;
/// - an HTTP client that honors the configured proxy;
/// - the production `Client`, a real filesystem, a language registry and a Node runtime;
/// - the language, language-model, tool, context-server and agent subsystems.
///
/// # Panics
/// Panics if the default settings or `runner_settings.json` fail to load, or if the HTTP
/// client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Use the proxy from settings when it parses as a URI. Otherwise fall back to
    // `read_proxy_from_env` (presumably the standard proxy environment variables).
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the reqwest client while the gpui Tokio runtime context is entered.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // NOTE(review): this replaces the proxy-aware client installed above with the production
    // client's HTTP client; presumably `Client::production` wraps the former. Confirm.
    cx.set_http_client(client.http_client().clone());

    // No explicit git binary is configured for the filesystem.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the settings store changes, re-derive the Node binary options from the project
    // settings and publish them to the Node runtime through a watch channel.
    // NOTE(review): the channel starts as `None`, so options exist only after the observer
    // fires (e.g. on the `update_global` call at the end of this function). Confirm this is
    // sufficient.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, look for `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    // The eval runner treats stdout as non-interactive.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the eval runner's own user settings on top of the defaults.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
485
486pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
487 let model_registry = LanguageModelRegistry::read_global(cx);
488 let model = model_registry
489 .available_models(cx)
490 .find(|model| model.id().0 == model_name);
491
492 let Some(model) = model else {
493 return Err(anyhow!(
494 "No language model named {} was available. Available models: {}",
495 model_name,
496 model_registry
497 .available_models(cx)
498 .map(|model| model.id().0.clone())
499 .collect::<Vec<_>>()
500 .join(", ")
501 ));
502 };
503
504 Ok(model)
505}
506
507pub fn authenticate_model_provider(
508 provider_id: LanguageModelProviderId,
509 cx: &mut App,
510) -> Task<std::result::Result<(), AuthenticateError>> {
511 let model_registry = LanguageModelRegistry::read_global(cx);
512 let model_provider = model_registry.provider(&provider_id).unwrap();
513 model_provider.authenticate(cx)
514}