1mod example;
2
3use client::{Client, ProxySettings, UserStore};
4pub(crate) use example::*;
5
6use ::fs::RealFs;
7use anyhow::{Result, anyhow};
8use clap::Parser;
9use extension::ExtensionHostProxy;
10use futures::future;
11use gpui::http_client::{Uri, read_proxy_from_env};
12use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
13use gpui_tokio::Tokio;
14use language::LanguageRegistry;
15use language_model::{
16 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
17};
18use node_runtime::{NodeBinaryOptions, NodeRuntime};
19use project::Project;
20use project::project_settings::ProjectSettings;
21use prompt_store::PromptBuilder;
22use release_channel::AppVersion;
23use reqwest_client::ReqwestClient;
24use settings::{Settings, SettingsStore};
25use std::collections::HashSet;
26use std::path::{Path, PathBuf};
27use std::sync::Arc;
28use std::usize;
29use util::ResultExt as _;
30
/// Directory under which `main` creates one timestamped sub-directory per run.
///
/// The path is relative, so it is resolved against the process's working directory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
32
// Command-line arguments of the eval runner, parsed with clap's derive API.
//
// NOTE: this is deliberately a plain `//` comment. clap uses `///` doc comments on the
// struct and its fields as the generated help text, so editing those changes what
// `--help` prints.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
    // `None` when the flag is absent; `main` then falls back to `["rs"]`.
    #[arg(long, value_delimiter = ',')]
    languages: Option<Vec<String>>,
}
46
47fn main() {
48 env_logger::init();
49
50 let args = Args::parse();
51 let all_available_examples = list_all_examples().unwrap();
52 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
53
54 let example_paths = all_available_examples
55 .iter()
56 .filter_map(|example_path| {
57 let name = example_path.file_name()?.to_string_lossy();
58 if args.examples.is_empty()
59 || args
60 .examples
61 .iter()
62 .any(|name_substring| name.contains(name_substring))
63 {
64 Some(example_path.clone())
65 } else {
66 None
67 }
68 })
69 .collect::<Vec<_>>();
70
71 let http_client = Arc::new(ReqwestClient::new());
72 let app = Application::headless().with_http_client(http_client.clone());
73
74 app.run(move |cx| {
75 let app_state = init(cx);
76
77 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
78
79 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
80 registry.set_default_model(Some(model.clone()), cx);
81 });
82
83 let model_provider_id = model.provider_id();
84
85 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
86
87 cx.spawn(async move |cx| {
88 authenticate.await.unwrap();
89
90 std::fs::create_dir_all(REPOS_DIR)?;
91 std::fs::create_dir_all(WORKTREES_DIR)?;
92
93 let run_dir = Path::new(RUNS_DIR).join(format!(
94 "{}",
95 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
96 ));
97 std::fs::create_dir_all(&run_dir)?;
98
99 let mut examples = Vec::new();
100
101 const COLORS: [&str; 12] = [
102 "\x1b[31m", // Red
103 "\x1b[32m", // Green
104 "\x1b[33m", // Yellow
105 "\x1b[34m", // Blue
106 "\x1b[35m", // Magenta
107 "\x1b[36m", // Cyan
108 "\x1b[91m", // Bright Red
109 "\x1b[92m", // Bright Green
110 "\x1b[93m", // Bright Yellow
111 "\x1b[94m", // Bright Blue
112 "\x1b[95m", // Bright Magenta
113 "\x1b[96m", // Bright Cyan
114 ];
115
116 let mut max_name_width = 0;
117 let mut skipped = Vec::new();
118
119 for example_path in &example_paths {
120 let example = Example::load_from_directory(example_path, &run_dir)?;
121
122 if !example
123 .base
124 .language_extension
125 .as_ref()
126 .map_or(false, |lang| languages.contains(lang))
127 {
128 skipped.push(example.name);
129 continue;
130 }
131
132 let name_len = example.name.len();
133 if name_len > max_name_width {
134 max_name_width = example.name.len();
135 }
136
137 examples.push(example);
138 }
139
140 println!("Skipped examples: {}\n", skipped.join(", "));
141
142 if examples.is_empty() {
143 eprintln!("Filter matched no examples");
144 return cx.update(|cx| cx.quit());
145 }
146
147 let mut repo_urls = HashSet::new();
148 let mut clone_tasks = Vec::new();
149
150 for (i, example) in examples.iter_mut().enumerate() {
151 let color = COLORS[i % COLORS.len()].to_string();
152 example.set_log_prefix_style(&color, max_name_width);
153
154 println!(
155 "{}Logging to: {}",
156 example.log_prefix,
157 example.output_file_path.display()
158 );
159
160 let repo_url = example.base.url.clone();
161 if repo_urls.insert(repo_url.clone()) {
162 let repo_path = repo_path_for_url(&repo_url);
163
164 if !repo_path.join(".git").is_dir() {
165 println!(
166 "{:<width$} < {}",
167 "↓ Cloning",
168 repo_url,
169 width = max_name_width
170 );
171
172 let git_task = cx.spawn(async move |_cx| {
173 std::fs::create_dir_all(&repo_path)?;
174 run_git(&repo_path, &["init"]).await?;
175 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
176 });
177
178 clone_tasks.push(git_task);
179 } else {
180 println!(
181 "{:<width$} < {}",
182 "✔︎ Already cloned",
183 repo_url,
184 width = max_name_width
185 );
186
187 let actual_origin =
188 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
189 if actual_origin != repo_url {
190 return Err(anyhow!(
191 "remote origin {} does not match expected origin {}",
192 actual_origin,
193 repo_url,
194 ));
195 }
196 }
197 }
198 }
199
200 future::join_all(clone_tasks).await;
201
202 for example in examples.iter_mut() {
203 example.setup().await?;
204 }
205
206 let tasks = examples
207 .into_iter()
208 .map(|example| {
209 let app_state = app_state.clone();
210 let model = model.clone();
211 cx.spawn(async move |cx| {
212 (run_example(&example, model, app_state, cx).await, example)
213 })
214 })
215 .collect::<Vec<_>>();
216
217 let results: Vec<(Result<JudgeOutput>, Example)> = future::join_all(tasks).await;
218
219 println!("\n\n");
220 println!("========================================");
221 println!(" EVAL RESULTS ");
222 println!("========================================");
223 println!("");
224
225 let mut judge_scores = Vec::new();
226
227 for (result, example) in results {
228 match result {
229 Err(err) => {
230 println!("💥 {}{:?}", example.log_prefix, err);
231 }
232 Ok(judge_output) => {
233 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
234
235 println!(
236 "{} {}{}",
237 SCORES[judge_output.score.min(5) as usize],
238 example.log_prefix,
239 judge_output.score,
240 );
241 judge_scores.push(judge_output.score);
242 }
243 }
244 println!(
245 "{} > {}",
246 " ".repeat(max_name_width),
247 example.output_file_path.display()
248 );
249 }
250
251 let score_count = judge_scores.len();
252 let average_score = judge_scores
253 .into_iter()
254 .map(|score| score as f32)
255 .sum::<f32>()
256 / (score_count as f32);
257 println!("\nAverage score: {average_score}");
258
259 cx.update(|cx| cx.quit())
260 })
261 .detach_and_log_err(cx);
262 });
263}
264
265async fn run_example(
266 example: &Example,
267 model: Arc<dyn LanguageModel>,
268 app_state: Arc<AgentAppState>,
269 cx: &mut AsyncApp,
270) -> Result<JudgeOutput> {
271 cx.update(|cx| example.run(model.clone(), app_state, cx))?
272 .await?;
273 let diff = example.repository_diff().await?;
274 example.judge(model, diff, cx).await
275}
276
277fn list_all_examples() -> Result<Vec<PathBuf>> {
278 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
279 let entries = std::fs::read_dir(path).unwrap();
280 let mut result_paths = Vec::new();
281 for entry in entries {
282 let entry = entry?;
283 let path = entry.path();
284 if path.is_dir() {
285 result_paths.push(path);
286 }
287 }
288 Ok(result_paths)
289}
290
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by `init` and shared behind an `Arc` with every example run.
pub struct AgentAppState {
    /// Language registry; `init` hands the same registry to the language subsystems.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `init` via `Client::production`.
    pub client: Arc<Client>,
    /// User store entity created from `client`.
    pub user_store: Entity<UserStore>,
    /// File system implementation; `init` supplies a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options `init` derives from the `node` project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init`, which also passes it to `agent::init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
302
/// Initializes the global state required to run examples headlessly and returns the
/// shared handles bundled in an [`AgentAppState`].
///
/// This installs the settings store, the HTTP client, the client/user store, the
/// language registry, the node runtime, and the language, language-model, assistant-tool,
/// context-server and agent subsystems, and finally applies the bundled runner settings.
/// The calls are order-dependent (e.g. settings must exist before they are read).
///
/// # Panics
///
/// Panics if the default settings or the bundled runner settings cannot be applied, or
/// if the HTTP client cannot be created.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // The settings store must be installed as a global before any `*::init_settings`
    // call or `get_global` lookup below.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from the settings when it parses as a URI; otherwise fall back to
    // whatever `read_proxy_from_env` returns.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // NOTE(review): presumably the reqwest client has to be constructed inside a
        // Tokio runtime context, hence the guard — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // From here on the app-wide HTTP client is the one owned by `client`.
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the settings store changes, recompute the Node binary options from the
    // `node` project settings and publish them to the node runtime through the channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        // A failed send is only logged, not treated as fatal.
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the settings bundled with the runner as the user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
406
407pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
408 let model_registry = LanguageModelRegistry::read_global(cx);
409 let model = model_registry
410 .available_models(cx)
411 .find(|model| model.id().0 == model_name);
412
413 let Some(model) = model else {
414 return Err(anyhow!(
415 "No language model named {} was available. Available models: {}",
416 model_name,
417 model_registry
418 .available_models(cx)
419 .map(|model| model.id().0.clone())
420 .collect::<Vec<_>>()
421 .join(", ")
422 ));
423 };
424
425 Ok(model)
426}
427
428pub fn authenticate_model_provider(
429 provider_id: LanguageModelProviderId,
430 cx: &mut App,
431) -> Task<std::result::Result<(), AuthenticateError>> {
432 let model_registry = LanguageModelRegistry::read_global(cx);
433 let model_provider = model_registry.provider(&provider_id).unwrap();
434 model_provider.authenticate(cx)
435}