1mod example;
2
3use assistant_settings::AssistantSettings;
4use client::{Client, ProxySettings, UserStore};
5pub(crate) use example::*;
6
7use ::fs::RealFs;
8use anyhow::{Result, anyhow};
9use clap::Parser;
10use extension::ExtensionHostProxy;
11use futures::future;
12use gpui::http_client::{Uri, read_proxy_from_env};
13use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
14use gpui_tokio::Tokio;
15use language::LanguageRegistry;
16use language_model::{
17 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
18};
19use node_runtime::{NodeBinaryOptions, NodeRuntime};
20use project::Project;
21use project::project_settings::ProjectSettings;
22use prompt_store::PromptBuilder;
23use release_channel::AppVersion;
24use reqwest_client::ReqwestClient;
25use settings::{Settings, SettingsStore};
26use std::collections::HashSet;
27use std::path::{Path, PathBuf};
28use std::sync::Arc;
29use std::usize;
30use util::ResultExt as _;
31
/// Directory, relative to the current working directory, under which `main`
/// creates one timestamped subdirectory (`%Y-%m-%d_%H-%M-%S`) per eval run.
pub const RUNS_DIR: &str = "./crates/eval/runs";
33
// Command-line arguments for the eval runner.
//
// NOTE: the `///` comments on the fields are rendered by clap as `--help`
// text, i.e. they are user-facing output. Reviewer notes therefore use plain
// `//` comments so they do not change the CLI's help output.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    // Each substring is matched against the example directory's file name;
    // an example runs if it contains any of them.
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    // Looked up by exact model id in the language model registry.
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    // Compared against each example's language extension (e.g. "rs").
    /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
    #[arg(long, value_delimiter = ',')]
    languages: Option<Vec<String>>,
}
47
48fn main() {
49 env_logger::init();
50
51 let args = Args::parse();
52 let all_available_examples = list_all_examples().unwrap();
53 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
54
55 let example_paths = all_available_examples
56 .iter()
57 .filter_map(|example_path| {
58 let name = example_path.file_name()?.to_string_lossy();
59 if args.examples.is_empty()
60 || args
61 .examples
62 .iter()
63 .any(|name_substring| name.contains(name_substring))
64 {
65 Some(example_path.clone())
66 } else {
67 None
68 }
69 })
70 .collect::<Vec<_>>();
71
72 let http_client = Arc::new(ReqwestClient::new());
73 let app = Application::headless().with_http_client(http_client.clone());
74
75 app.run(move |cx| {
76 let app_state = init(cx);
77
78 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
79
80 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
81 registry.set_default_model(Some(model.clone()), cx);
82 });
83
84 let model_provider_id = model.provider_id();
85
86 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
87
88 cx.spawn(async move |cx| {
89 authenticate.await.unwrap();
90
91 std::fs::create_dir_all(REPOS_DIR)?;
92 std::fs::create_dir_all(WORKTREES_DIR)?;
93
94 let run_dir = Path::new(RUNS_DIR).join(format!(
95 "{}",
96 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
97 ));
98 std::fs::create_dir_all(&run_dir)?;
99
100 let mut examples = Vec::new();
101
102 const COLORS: [&str; 12] = [
103 "\x1b[31m", // Red
104 "\x1b[32m", // Green
105 "\x1b[33m", // Yellow
106 "\x1b[34m", // Blue
107 "\x1b[35m", // Magenta
108 "\x1b[36m", // Cyan
109 "\x1b[91m", // Bright Red
110 "\x1b[92m", // Bright Green
111 "\x1b[93m", // Bright Yellow
112 "\x1b[94m", // Bright Blue
113 "\x1b[95m", // Bright Magenta
114 "\x1b[96m", // Bright Cyan
115 ];
116
117 let mut max_name_width = 0;
118 let mut skipped = Vec::new();
119
120 for example_path in &example_paths {
121 let example = Example::load_from_directory(example_path, &run_dir)?;
122
123 if !example
124 .base
125 .language_extension
126 .as_ref()
127 .map_or(false, |lang| languages.contains(lang))
128 {
129 skipped.push(example.name);
130 continue;
131 }
132
133 let name_len = example.name.len();
134 if name_len > max_name_width {
135 max_name_width = example.name.len();
136 }
137
138 examples.push(example);
139 }
140
141 println!("Skipped examples: {}\n", skipped.join(", "));
142
143 if examples.is_empty() {
144 eprintln!("Filter matched no examples");
145 return cx.update(|cx| cx.quit());
146 }
147
148 let mut repo_urls = HashSet::new();
149 let mut clone_tasks = Vec::new();
150
151 for (i, example) in examples.iter_mut().enumerate() {
152 let color = COLORS[i % COLORS.len()].to_string();
153 example.set_log_prefix_style(&color, max_name_width);
154
155 println!(
156 "{}Logging to: {}",
157 example.log_prefix,
158 example.output_file_path.display()
159 );
160
161 let repo_url = example.base.url.clone();
162 if repo_urls.insert(repo_url.clone()) {
163 let repo_path = repo_path_for_url(&repo_url);
164
165 if !repo_path.join(".git").is_dir() {
166 println!(
167 "{:<width$} < {}",
168 "↓ Cloning",
169 repo_url,
170 width = max_name_width
171 );
172
173 let git_task = cx.spawn(async move |_cx| {
174 std::fs::create_dir_all(&repo_path)?;
175 run_git(&repo_path, &["init"]).await?;
176 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
177 });
178
179 clone_tasks.push(git_task);
180 } else {
181 println!(
182 "{:<width$} < {}",
183 "✔︎ Already cloned",
184 repo_url,
185 width = max_name_width
186 );
187
188 let actual_origin =
189 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
190 if actual_origin != repo_url {
191 return Err(anyhow!(
192 "remote origin {} does not match expected origin {}",
193 actual_origin,
194 repo_url,
195 ));
196 }
197 }
198 }
199 }
200
201 future::join_all(clone_tasks).await;
202
203 for example in examples.iter() {
204 example.setup().await?;
205 }
206
207 let tasks = examples
208 .into_iter()
209 .map(|example| {
210 let app_state = app_state.clone();
211 let model = model.clone();
212 cx.spawn(async move |cx| {
213 (run_example(&example, model, app_state, cx).await, example)
214 })
215 })
216 .collect::<Vec<_>>();
217
218 let results: Vec<(Result<JudgeOutput>, Example)> = future::join_all(tasks).await;
219
220 println!("\n\n");
221 println!("========================================");
222 println!(" EVAL RESULTS ");
223 println!("========================================");
224 println!("");
225
226 let mut judge_scores = Vec::new();
227
228 for (result, example) in results {
229 match result {
230 Err(err) => {
231 println!("💥 {}{:?}", example.log_prefix, err);
232 }
233 Ok(judge_output) => {
234 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
235
236 println!(
237 "{} {}{}",
238 SCORES[judge_output.score.min(5) as usize],
239 example.log_prefix,
240 judge_output.score,
241 );
242 judge_scores.push(judge_output.score);
243 }
244 }
245 println!(
246 "{} > {}",
247 " ".repeat(max_name_width),
248 example.output_file_path.display()
249 );
250 }
251
252 let score_count = judge_scores.len();
253 let average_score = judge_scores
254 .into_iter()
255 .map(|score| score as f32)
256 .sum::<f32>()
257 / (score_count as f32);
258 println!("\nAverage score: {average_score}");
259
260 cx.update(|cx| cx.quit())
261 })
262 .detach_and_log_err(cx);
263 });
264}
265
266async fn run_example(
267 example: &Example,
268 model: Arc<dyn LanguageModel>,
269 app_state: Arc<AgentAppState>,
270 cx: &mut AsyncApp,
271) -> Result<JudgeOutput> {
272 cx.update(|cx| example.run(model.clone(), app_state, cx))?
273 .await?;
274 let diff = example.repository_diff().await?;
275 example.judge(model, diff, cx).await
276}
277
278fn list_all_examples() -> Result<Vec<PathBuf>> {
279 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
280 let entries = std::fs::read_dir(path).unwrap();
281 let mut result_paths = Vec::new();
282 for entry in entries {
283 let entry = entry?;
284 let path = entry.path();
285 if path.is_dir() {
286 result_paths.push(path);
287 }
288 }
289 Ok(result_paths)
290}
291
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by `init` and shared (via `Arc`) by every example run.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download dir at `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// The `Client` created with `Client::production` in `init`.
    pub client: Arc<Client>,
    /// User store entity built from `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem implementation; `init` supplies a `RealFs` with no explicit git binary path.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose options are fed from the project settings in `init`.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `PromptBuilder::load` in `init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
303
/// Initializes the global state a headless agent needs and returns it as an
/// [`AgentAppState`].
///
/// The order of the calls below matters: the settings store must be a global
/// before any `*::init_settings` / `get_global` call, and subsystems are
/// initialized after the client, filesystem and language registry they take.
///
/// # Panics
///
/// Panics if the default settings fail to load, or if the HTTP client cannot
/// be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // NOTE(review): uses `SemanticVersion::default()` as the app version; this
    // version also appears in the User-Agent below.
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Proxy: the configured setting wins if it parses as a URI; otherwise
    // fall back to the proxy read from the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // The client is built while the Tokio runtime handle is entered;
        // presumably reqwest needs a runtime context here — TODO confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // `Client::production` is created after the proxy-aware HTTP client is
    // installed; the app's HTTP client is then replaced by the client's own.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Node binary options are published through a watch channel that starts
    // as `None`. The observer below sends new options only when the
    // `SettingsStore` global changes.
    // NOTE(review): nothing here sends an initial value directly; the
    // `AssistantSettings::override_global` call at the end presumably
    // triggers the first notification — verify before reordering this code.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            // A configured node path is tilde-expanded. npm defaults to an
            // `npm` file next to the node binary when no npm path is set.
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Evals run unattended, so tool actions are always allowed; all other
    // assistant settings keep their current global values.
    AssistantSettings::override_global(
        AssistantSettings {
            always_allow_tool_actions: true,
            ..AssistantSettings::get_global(cx).clone()
        },
        cx,
    );

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
410
411pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
412 let model_registry = LanguageModelRegistry::read_global(cx);
413 let model = model_registry
414 .available_models(cx)
415 .find(|model| model.id().0 == model_name);
416
417 let Some(model) = model else {
418 return Err(anyhow!(
419 "No language model named {} was available. Available models: {}",
420 model_name,
421 model_registry
422 .available_models(cx)
423 .map(|model| model.id().0.clone())
424 .collect::<Vec<_>>()
425 .join(", ")
426 ));
427 };
428
429 Ok(model)
430}
431
432pub fn authenticate_model_provider(
433 provider_id: LanguageModelProviderId,
434 cx: &mut App,
435) -> Task<std::result::Result<(), AuthenticateError>> {
436 let model_registry = LanguageModelRegistry::read_global(cx);
437 let model_provider = model_registry.provider(&provider_id).unwrap();
438 model_provider.authenticate(cx)
439}