1mod example;
2
3use client::{Client, ProxySettings, UserStore};
4pub(crate) use example::*;
5
6use ::fs::RealFs;
7use anyhow::{Result, anyhow};
8use clap::Parser;
9use extension::ExtensionHostProxy;
10use futures::future;
11use gpui::http_client::{Uri, read_proxy_from_env};
12use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
13use gpui_tokio::Tokio;
14use language::LanguageRegistry;
15use language_model::{
16 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
17};
18use node_runtime::{NodeBinaryOptions, NodeRuntime};
19use project::Project;
20use project::project_settings::ProjectSettings;
21use prompt_store::PromptBuilder;
22use release_channel::AppVersion;
23use reqwest_client::ReqwestClient;
24use settings::{Settings, SettingsStore};
25use std::collections::HashSet;
26use std::path::{Path, PathBuf};
27use std::sync::Arc;
28use std::usize;
29use util::ResultExt as _;
30
/// Parent directory (relative to the current working directory) for eval run output;
/// `main` creates one timestamped subdirectory here per run.
pub const RUNS_DIR: &str = "./crates/eval/runs";
32
33#[derive(Parser, Debug)]
34#[command(name = "eval", disable_version_flag = true)]
35struct Args {
36 /// Runs all examples that contain these substrings. If unspecified, all examples are run.
37 #[arg(value_name = "EXAMPLE_SUBSTRING")]
38 examples: Vec<String>,
39 /// Model to use (default: "claude-3-7-sonnet-latest")
40 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
41 model: String,
42 /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
43 #[arg(long, value_delimiter = ',')]
44 languages: Option<Vec<String>>,
45 /// How many times to run the judge on each example run.
46 #[arg(long, default_value = "3")]
47 judge_repetitions: u32,
48}
49
50fn main() {
51 env_logger::init();
52
53 let args = Args::parse();
54 let all_available_examples = list_all_examples().unwrap();
55 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
56
57 let example_paths = all_available_examples
58 .iter()
59 .filter_map(|example_path| {
60 let name = example_path.file_name()?.to_string_lossy();
61 if args.examples.is_empty()
62 || args
63 .examples
64 .iter()
65 .any(|name_substring| name.contains(name_substring))
66 {
67 Some(example_path.clone())
68 } else {
69 None
70 }
71 })
72 .collect::<Vec<_>>();
73
74 let http_client = Arc::new(ReqwestClient::new());
75 let app = Application::headless().with_http_client(http_client.clone());
76
77 app.run(move |cx| {
78 let app_state = init(cx);
79
80 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
81
82 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
83 registry.set_default_model(Some(model.clone()), cx);
84 });
85
86 let model_provider_id = model.provider_id();
87
88 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
89
90 cx.spawn(async move |cx| {
91 authenticate.await.unwrap();
92
93 std::fs::create_dir_all(REPOS_DIR)?;
94 std::fs::create_dir_all(WORKTREES_DIR)?;
95
96 let run_dir = Path::new(RUNS_DIR).join(format!(
97 "{}",
98 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
99 ));
100 std::fs::create_dir_all(&run_dir)?;
101
102 let mut examples = Vec::new();
103
104 const COLORS: [&str; 12] = [
105 "\x1b[31m", // Red
106 "\x1b[32m", // Green
107 "\x1b[33m", // Yellow
108 "\x1b[34m", // Blue
109 "\x1b[35m", // Magenta
110 "\x1b[36m", // Cyan
111 "\x1b[91m", // Bright Red
112 "\x1b[92m", // Bright Green
113 "\x1b[93m", // Bright Yellow
114 "\x1b[94m", // Bright Blue
115 "\x1b[95m", // Bright Magenta
116 "\x1b[96m", // Bright Cyan
117 ];
118
119 let mut max_name_width = 0;
120 let mut skipped = Vec::new();
121
122 for example_path in &example_paths {
123 let example = Example::load_from_directory(example_path, &run_dir)?;
124
125 if !example
126 .base
127 .language_extension
128 .as_ref()
129 .map_or(false, |lang| languages.contains(lang))
130 {
131 skipped.push(example.name);
132 continue;
133 }
134
135 let name_len = example.name.len();
136 if name_len > max_name_width {
137 max_name_width = example.name.len();
138 }
139
140 examples.push(example);
141 }
142
143 println!("Skipped examples: {}\n", skipped.join(", "));
144
145 if examples.is_empty() {
146 eprintln!("Filter matched no examples");
147 return cx.update(|cx| cx.quit());
148 }
149
150 let mut repo_urls = HashSet::new();
151 let mut clone_tasks = Vec::new();
152
153 for (i, example) in examples.iter_mut().enumerate() {
154 let color = COLORS[i % COLORS.len()].to_string();
155 example.set_log_prefix_style(&color, max_name_width);
156
157 println!(
158 "{}Logging to: {}",
159 example.log_prefix,
160 example.output_file_path.display()
161 );
162
163 let repo_url = example.base.url.clone();
164 if repo_urls.insert(repo_url.clone()) {
165 let repo_path = repo_path_for_url(&repo_url);
166
167 if !repo_path.join(".git").is_dir() {
168 println!(
169 "{:<width$} < {}",
170 "↓ Cloning",
171 repo_url,
172 width = max_name_width
173 );
174
175 let git_task = cx.spawn(async move |_cx| {
176 std::fs::create_dir_all(&repo_path)?;
177 run_git(&repo_path, &["init"]).await?;
178 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
179 });
180
181 clone_tasks.push(git_task);
182 } else {
183 println!(
184 "{:<width$} < {}",
185 "✔︎ Already cloned",
186 repo_url,
187 width = max_name_width
188 );
189
190 let actual_origin =
191 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
192 if actual_origin != repo_url {
193 return Err(anyhow!(
194 "remote origin {} does not match expected origin {}",
195 actual_origin,
196 repo_url,
197 ));
198 }
199 }
200 }
201 }
202
203 future::join_all(clone_tasks).await;
204
205 for example in examples.iter_mut() {
206 example.setup().await?;
207 }
208
209 let judge_repetitions = args.judge_repetitions;
210 let tasks = examples
211 .into_iter()
212 .map(|example| {
213 let app_state = app_state.clone();
214 let model = model.clone();
215 cx.spawn(async move |cx| {
216 (
217 run_example(&example, model, app_state, judge_repetitions, cx).await,
218 example,
219 )
220 })
221 })
222 .collect::<Vec<_>>();
223
224 let results: Vec<(Result<Vec<Result<JudgeOutput>>>, Example)> =
225 future::join_all(tasks).await;
226
227 println!("\n\n");
228 println!("========================================");
229 println!(" EVAL RESULTS ");
230 println!("========================================");
231 println!("");
232
233 let mut judge_scores = Vec::new();
234
235 for (result, example) in results {
236 match result {
237 Err(err) => {
238 println!("💥 {}{:?}", example.log_prefix, err);
239 }
240 Ok(judge_results) => {
241 for judge_result in judge_results {
242 match judge_result {
243 Ok(judge_output) => {
244 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
245
246 println!(
247 "{} {}{}",
248 SCORES[judge_output.score.min(5) as usize],
249 example.log_prefix,
250 judge_output.score,
251 );
252 judge_scores.push(judge_output.score);
253 }
254 Err(err) => {
255 println!("💥 {}{:?}", example.log_prefix, err);
256 }
257 }
258 }
259 }
260 }
261 println!(
262 "{} > {}",
263 " ".repeat(max_name_width),
264 example.output_file_path.display()
265 );
266 }
267
268 let score_count = judge_scores.len();
269 let average_score = judge_scores
270 .into_iter()
271 .map(|score| score as f32)
272 .sum::<f32>()
273 / (score_count as f32);
274 println!("\nAverage score: {average_score}");
275
276 cx.update(|cx| cx.quit())
277 })
278 .detach_and_log_err(cx);
279 });
280}
281
282async fn run_example(
283 example: &Example,
284 model: Arc<dyn LanguageModel>,
285 app_state: Arc<AgentAppState>,
286 judge_repetitions: u32,
287 cx: &mut AsyncApp,
288) -> Result<Vec<Result<JudgeOutput>>> {
289 cx.update(|cx| example.run(model.clone(), app_state, cx))?
290 .await?;
291 let diff = example.repository_diff().await?;
292
293 let judge_tasks = (0..judge_repetitions)
294 .map(|round| example.judge(model.clone(), diff.clone(), round, cx))
295 .collect::<Vec<_>>();
296
297 Ok(future::join_all(judge_tasks).await)
298}
299
300fn list_all_examples() -> Result<Vec<PathBuf>> {
301 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
302 let entries = std::fs::read_dir(path).unwrap();
303 let mut result_paths = Vec::new();
304 for entry in entries {
305 let entry = entry?;
306 let path = entry.path();
307 if path.is_dir() {
308 result_paths.push(path);
309 }
310 }
311 Ok(result_paths)
312}
313
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Constructed once by `init` and shared behind an `Arc`; `main` hands a clone of that `Arc`
/// to each example run.
pub struct AgentAppState {
    /// Language registry; `init` sets its language-server download dir to `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `init` via `Client::production`.
    pub client: Arc<Client>,
    /// User store entity created by `init` from `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem handle; `init` provides a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options follow the project `node` settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init` with `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
325
/// Initializes the global state the headless eval runner needs and returns the shared handles.
///
/// The steps run in a fixed order:
/// 1. Register settings: first the defaults, and at the very end `../runner_settings.json` as
///    user settings.
/// 2. Install an HTTP client.
/// 3. Initialize the language, language-model, tool, context-server and agent subsystems.
///
/// # Panics
///
/// Panics if the default settings or `../runner_settings.json` cannot be applied, or if the
/// proxied HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Use the `proxy` setting when present and parseable as a `Uri`; otherwise fall back to
    // whatever `read_proxy_from_env` finds.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the client while the Tokio runtime handle is entered.
        // NOTE(review): presumably reqwest requires an active Tokio context here — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // Replaces the app-wide HTTP client installed above with the one held by `client`.
    // NOTE(review): whether `client`'s HTTP client wraps the proxied one is not visible here.
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Node binary options are published on a watch channel that starts at `None`; the
    // `SettingsStore` observer recomputes and sends them from the project `node` settings.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            // Path lookup is allowed unless `ignore_system_version` is set.
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            // An explicit node path (with `~` expanded) pins node and npm.
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary (or a
                    // bare relative `npm` when the node path has no parent).
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    // The runner is non-interactive, so the prompt builder is told stdout is not a PTY.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the runner's settings last, as user settings, on top of the defaults.
    // NOTE(review): this global update presumably also fires the `SettingsStore` observer
    // above, which is what first publishes node options — confirm.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
429
430pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
431 let model_registry = LanguageModelRegistry::read_global(cx);
432 let model = model_registry
433 .available_models(cx)
434 .find(|model| model.id().0 == model_name);
435
436 let Some(model) = model else {
437 return Err(anyhow!(
438 "No language model named {} was available. Available models: {}",
439 model_name,
440 model_registry
441 .available_models(cx)
442 .map(|model| model.id().0.clone())
443 .collect::<Vec<_>>()
444 .join(", ")
445 ));
446 };
447
448 Ok(model)
449}
450
451pub fn authenticate_model_provider(
452 provider_id: LanguageModelProviderId,
453 cx: &mut App,
454) -> Task<std::result::Result<(), AuthenticateError>> {
455 let model_registry = LanguageModelRegistry::read_global(cx);
456 let model_provider = model_registry.provider(&provider_id).unwrap();
457 model_provider.authenticate(cx)
458}