1mod example;
2mod ids;
3
4use client::{Client, ProxySettings, UserStore};
5pub(crate) use example::*;
6use telemetry;
7
8use ::fs::RealFs;
9use anyhow::{Result, anyhow};
10use clap::Parser;
11use extension::ExtensionHostProxy;
12use futures::future;
13use gpui::http_client::{Uri, read_proxy_from_env};
14use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
15use gpui_tokio::Tokio;
16use language::LanguageRegistry;
17use language_model::{
18 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
19};
20use node_runtime::{NodeBinaryOptions, NodeRuntime};
21use project::Project;
22use project::project_settings::ProjectSettings;
23use prompt_store::PromptBuilder;
24use release_channel::AppVersion;
25use reqwest_client::ReqwestClient;
26use settings::{Settings, SettingsStore};
27use std::collections::HashSet;
28use std::path::{Path, PathBuf};
29use std::sync::Arc;
30use std::usize;
31use util::ResultExt as _;
32
/// Directory, relative to the working directory, under which every eval run
/// creates its own timestamped output directory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
34
35#[derive(Parser, Debug)]
36#[command(name = "eval", disable_version_flag = true)]
37struct Args {
38 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
39 #[arg(value_name = "EXAMPLE_SUBSTRING")]
40 examples: Vec<String>,
41 /// Model to use (default: "claude-3-7-sonnet-latest")
42 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
43 model: String,
44 #[arg(long, value_delimiter = ',')]
45 languages: Option<Vec<String>>,
46 /// How many times to run the judge on each example run.
47 #[arg(long, default_value = "3")]
48 judge_repetitions: u32,
49}
50
51fn main() {
52 env_logger::init();
53
54 let args = Args::parse();
55 let all_available_examples = list_all_examples().unwrap();
56 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
57
58 let example_paths = all_available_examples
59 .iter()
60 .filter_map(|example_path| {
61 let name = example_path.file_name()?.to_string_lossy();
62 if args.examples.is_empty()
63 || args
64 .examples
65 .iter()
66 .any(|name_substring| name.contains(name_substring))
67 {
68 Some(example_path.clone())
69 } else {
70 None
71 }
72 })
73 .collect::<Vec<_>>();
74
75 let http_client = Arc::new(ReqwestClient::new());
76 let app = Application::headless().with_http_client(http_client.clone());
77
78 app.run(move |cx| {
79 let app_state = init(cx);
80
81 let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
82 let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
83 let session_id = uuid::Uuid::new_v4().to_string();
84
85 app_state
86 .client
87 .telemetry()
88 .start(system_id, installation_id, session_id, cx);
89
90 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
91
92 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
93 registry.set_default_model(Some(model.clone()), cx);
94 });
95
96 let model_provider_id = model.provider_id();
97
98 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
99
100 cx.spawn(async move |cx| {
101 authenticate.await.unwrap();
102
103 std::fs::create_dir_all(REPOS_DIR)?;
104 std::fs::create_dir_all(WORKTREES_DIR)?;
105
106 let run_dir = Path::new(RUNS_DIR).join(format!(
107 "{}",
108 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
109 ));
110 std::fs::create_dir_all(&run_dir)?;
111
112 let mut examples = Vec::new();
113
114 const COLORS: [&str; 12] = [
115 "\x1b[31m", // Red
116 "\x1b[32m", // Green
117 "\x1b[33m", // Yellow
118 "\x1b[34m", // Blue
119 "\x1b[35m", // Magenta
120 "\x1b[36m", // Cyan
121 "\x1b[91m", // Bright Red
122 "\x1b[92m", // Bright Green
123 "\x1b[93m", // Bright Yellow
124 "\x1b[94m", // Bright Blue
125 "\x1b[95m", // Bright Magenta
126 "\x1b[96m", // Bright Cyan
127 ];
128
129 let mut max_name_width = 0;
130 let mut skipped = Vec::new();
131
132 for example_path in &example_paths {
133 let example = Example::load_from_directory(example_path, &run_dir)?;
134
135 if !example
136 .base
137 .language_extension
138 .as_ref()
139 .map_or(false, |lang| languages.contains(lang))
140 {
141 skipped.push(example.name);
142 continue;
143 }
144
145 let name_len = example.name.len();
146 if name_len > max_name_width {
147 max_name_width = example.name.len();
148 }
149
150 examples.push(example);
151 }
152
153 println!("Skipped examples: {}\n", skipped.join(", "));
154
155 if examples.is_empty() {
156 eprintln!("Filter matched no examples");
157 return cx.update(|cx| cx.quit());
158 }
159
160 let mut repo_urls = HashSet::new();
161 let mut clone_tasks = Vec::new();
162
163 for (i, example) in examples.iter_mut().enumerate() {
164 let color = COLORS[i % COLORS.len()].to_string();
165 example.set_log_prefix_style(&color, max_name_width);
166
167 println!(
168 "{}Logging to: {}",
169 example.log_prefix,
170 example.output_file_path.display()
171 );
172
173 let repo_url = example.base.url.clone();
174 if repo_urls.insert(repo_url.clone()) {
175 let repo_path = repo_path_for_url(&repo_url);
176
177 if !repo_path.join(".git").is_dir() {
178 println!(
179 "{:<width$} < {}",
180 "↓ Cloning",
181 repo_url,
182 width = max_name_width
183 );
184
185 let git_task = cx.spawn(async move |_cx| {
186 std::fs::create_dir_all(&repo_path)?;
187 run_git(&repo_path, &["init"]).await?;
188 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
189 });
190
191 clone_tasks.push(git_task);
192 } else {
193 println!(
194 "{:<width$} < {}",
195 "✔︎ Already cloned",
196 repo_url,
197 width = max_name_width
198 );
199
200 let actual_origin =
201 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
202 if actual_origin != repo_url {
203 return Err(anyhow!(
204 "remote origin {} does not match expected origin {}",
205 actual_origin,
206 repo_url,
207 ));
208 }
209 }
210 }
211 }
212
213 future::join_all(clone_tasks).await;
214
215 for example in examples.iter_mut() {
216 example.setup().await?;
217 }
218
219 let judge_repetitions = args.judge_repetitions;
220 let tasks = examples
221 .into_iter()
222 .map(|example| {
223 let app_state = app_state.clone();
224 let model = model.clone();
225 cx.spawn(async move |cx| {
226 (
227 run_example(&example, model, app_state, judge_repetitions, cx).await,
228 example,
229 )
230 })
231 })
232 .collect::<Vec<_>>();
233
234 let results: Vec<(Result<Vec<Result<JudgeOutput>>>, Example)> =
235 future::join_all(tasks).await;
236
237 println!("\n\n");
238 println!("========================================");
239 println!(" EVAL RESULTS ");
240 println!("========================================");
241 println!("");
242
243 let mut judge_scores = Vec::new();
244
245 for (result, example) in results {
246 match result {
247 Err(err) => {
248 println!("💥 {}{:?}", example.log_prefix, err);
249 }
250 Ok(judge_results) => {
251 for judge_result in judge_results {
252 match judge_result {
253 Ok(judge_output) => {
254 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
255
256 println!(
257 "{} {}{}",
258 SCORES[judge_output.score.min(5) as usize],
259 example.log_prefix,
260 judge_output.score,
261 );
262 judge_scores.push(judge_output.score);
263 }
264 Err(err) => {
265 println!("💥 {}{:?}", example.log_prefix, err);
266 }
267 }
268 }
269 }
270 }
271 println!(
272 "{} > {}",
273 " ".repeat(max_name_width),
274 example.output_file_path.display()
275 );
276 }
277
278 let score_count = judge_scores.len();
279 let average_score = judge_scores
280 .into_iter()
281 .map(|score| score as f32)
282 .sum::<f32>()
283 / (score_count as f32);
284 println!("\nAverage score: {average_score}");
285
286 std::thread::sleep(std::time::Duration::from_secs(2));
287
288 // Flush telemetry events before exiting
289 app_state.client.telemetry().flush_events();
290
291 cx.update(|cx| cx.quit())
292 })
293 .detach_and_log_err(cx);
294 });
295}
296
297async fn run_example(
298 example: &Example,
299 model: Arc<dyn LanguageModel>,
300 app_state: Arc<AgentAppState>,
301 judge_repetitions: u32,
302 cx: &mut AsyncApp,
303) -> Result<Vec<Result<JudgeOutput>>> {
304 let run_output = cx
305 .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
306 .await?;
307 let diff = example.repository_diff().await?;
308
309 // Run judge for each repetition
310 let mut results = Vec::new();
311 for round in 0..judge_repetitions {
312 let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
313
314 // Log telemetry for this judge result
315 if let Ok(judge_output) = &judge_result {
316 let cohort_id = example
317 .output_file_path
318 .parent()
319 .and_then(|p| p.file_name())
320 .map(|name| name.to_string_lossy().to_string())
321 .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
322
323 telemetry::event!(
324 "Agent Eval Completed",
325 cohort_id = cohort_id,
326 example_name = example.name.clone(),
327 round = round,
328 score = judge_output.score,
329 analysis = judge_output.analysis,
330 tool_use_counts = run_output.tool_use_counts,
331 response_count = run_output.response_count,
332 token_usage = run_output.token_usage,
333 model = model.telemetry_id(),
334 model_provider = model.provider_id().to_string(),
335 repository_url = example.base.url.clone(),
336 repository_revision = example.base.revision.clone(),
337 diagnostics_summary = run_output.diagnostics
338 );
339 }
340
341 results.push(judge_result);
342 }
343
344 app_state.client.telemetry().flush_events();
345
346 Ok(results)
347}
348
349fn list_all_examples() -> Result<Vec<PathBuf>> {
350 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
351 let entries = std::fs::read_dir(path).unwrap();
352 let mut result_paths = Vec::new();
353 for entry in entries {
354 let entry = entry?;
355 let path = entry.path();
356 if path.is_dir() {
357 result_paths.push(path);
358 }
359 }
360 Ok(result_paths)
361}
362
363/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
364pub struct AgentAppState {
365 pub languages: Arc<LanguageRegistry>,
366 pub client: Arc<Client>,
367 pub user_store: Entity<UserStore>,
368 pub fs: Arc<dyn fs::Fs>,
369 pub node_runtime: NodeRuntime,
370
371 // Additional fields not present in `workspace::AppState`.
372 pub prompt_builder: Arc<PromptBuilder>,
373}
374
375pub fn init(cx: &mut App) -> Arc<AgentAppState> {
376 release_channel::init(SemanticVersion::default(), cx);
377 gpui_tokio::init(cx);
378
379 let mut settings_store = SettingsStore::new(cx);
380 settings_store
381 .set_default_settings(settings::default_settings().as_ref(), cx)
382 .unwrap();
383 cx.set_global(settings_store);
384 client::init_settings(cx);
385
386 // Set User-Agent so we can download language servers from GitHub
387 let user_agent = format!(
388 "Zed/{} ({}; {})",
389 AppVersion::global(cx),
390 std::env::consts::OS,
391 std::env::consts::ARCH
392 );
393 let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
394 let proxy_url = proxy_str
395 .as_ref()
396 .and_then(|input| input.parse::<Uri>().ok())
397 .or_else(read_proxy_from_env);
398 let http = {
399 let _guard = Tokio::handle(cx).enter();
400
401 ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
402 .expect("could not start HTTP client")
403 };
404 cx.set_http_client(Arc::new(http));
405
406 Project::init_settings(cx);
407
408 let client = Client::production(cx);
409 cx.set_http_client(client.http_client().clone());
410
411 let git_binary_path = None;
412 let fs = Arc::new(RealFs::new(
413 git_binary_path,
414 cx.background_executor().clone(),
415 ));
416
417 let mut languages = LanguageRegistry::new(cx.background_executor().clone());
418 languages.set_language_server_download_dir(paths::languages_dir().clone());
419 let languages = Arc::new(languages);
420
421 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
422
423 extension::init(cx);
424
425 let (tx, rx) = async_watch::channel(None);
426 cx.observe_global::<SettingsStore>(move |cx| {
427 let settings = &ProjectSettings::get_global(cx).node;
428 let options = NodeBinaryOptions {
429 allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
430 allow_binary_download: true,
431 use_paths: settings.path.as_ref().map(|node_path| {
432 let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
433 let npm_path = settings
434 .npm_path
435 .as_ref()
436 .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
437 (
438 node_path.clone(),
439 npm_path.unwrap_or_else(|| {
440 let base_path = PathBuf::new();
441 node_path.parent().unwrap_or(&base_path).join("npm")
442 }),
443 )
444 }),
445 };
446 tx.send(Some(options)).log_err();
447 })
448 .detach();
449 let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
450
451 let extension_host_proxy = ExtensionHostProxy::global(cx);
452
453 language::init(cx);
454 language_extension::init(extension_host_proxy.clone(), languages.clone());
455 language_model::init(client.clone(), cx);
456 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
457 languages::init(languages.clone(), node_runtime.clone(), cx);
458 assistant_tools::init(client.http_client().clone(), cx);
459 context_server::init(cx);
460 let stdout_is_a_pty = false;
461 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
462 agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
463
464 SettingsStore::update_global(cx, |store, cx| {
465 store.set_user_settings(include_str!("../runner_settings.json"), cx)
466 })
467 .unwrap();
468
469 Arc::new(AgentAppState {
470 languages,
471 client,
472 user_store,
473 fs,
474 node_runtime,
475 prompt_builder,
476 })
477}
478
479pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
480 let model_registry = LanguageModelRegistry::read_global(cx);
481 let model = model_registry
482 .available_models(cx)
483 .find(|model| model.id().0 == model_name);
484
485 let Some(model) = model else {
486 return Err(anyhow!(
487 "No language model named {} was available. Available models: {}",
488 model_name,
489 model_registry
490 .available_models(cx)
491 .map(|model| model.id().0.clone())
492 .collect::<Vec<_>>()
493 .join(", ")
494 ));
495 };
496
497 Ok(model)
498}
499
500pub fn authenticate_model_provider(
501 provider_id: LanguageModelProviderId,
502 cx: &mut App,
503) -> Task<std::result::Result<(), AuthenticateError>> {
504 let model_registry = LanguageModelRegistry::read_global(cx);
505 let model_provider = model_registry.provider(&provider_id).unwrap();
506 model_provider.authenticate(cx)
507}