1mod example;
2
3use assistant_settings::AssistantSettings;
4use client::{Client, ProxySettings, UserStore};
5pub(crate) use example::*;
6
7use ::fs::RealFs;
8use anyhow::{Result, anyhow};
9use clap::Parser;
10use extension::ExtensionHostProxy;
11use futures::future;
12use gpui::http_client::{Uri, read_proxy_from_env};
13use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
14use gpui_tokio::Tokio;
15use language::LanguageRegistry;
16use language_model::{
17 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
18};
19use node_runtime::{NodeBinaryOptions, NodeRuntime};
20use project::Project;
21use project::project_settings::ProjectSettings;
22use prompt_store::PromptBuilder;
23use release_channel::AppVersion;
24use reqwest_client::ReqwestClient;
25use settings::{Settings, SettingsStore};
26use std::collections::HashSet;
27use std::path::{Path, PathBuf};
28use std::sync::Arc;
29use util::ResultExt as _;
30
/// Directory (relative to the current working directory) under which each eval run creates its
/// own timestamped subdirectory, which is handed to `Example::load_from_directory`.
pub const RUNS_DIR: &str = "./crates/eval/runs";
32
// Command-line arguments for the eval runner.
// NOTE: the `///` doc comments in this struct are rendered by clap as `--help` text, so they are
// user-facing output; use `//` comments for internal notes.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    // Positional arguments; matched as substrings against each example directory's file name.
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    // Compared against `LanguageModel::id()` when looking up the model to evaluate.
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    // Compared against each example's language extension (e.g. "rs").
    /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
    #[arg(long, value_delimiter = ',')]
    languages: Option<Vec<String>>,
}
46
47fn main() {
48 env_logger::init();
49
50 let args = Args::parse();
51 let all_available_examples = list_all_examples().unwrap();
52 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
53
54 let example_paths = all_available_examples
55 .iter()
56 .filter_map(|example_path| {
57 let name = example_path.file_name()?.to_string_lossy();
58 if args.examples.is_empty()
59 || args
60 .examples
61 .iter()
62 .any(|name_substring| name.contains(name_substring))
63 {
64 Some(example_path.clone())
65 } else {
66 None
67 }
68 })
69 .collect::<Vec<_>>();
70
71 let http_client = Arc::new(ReqwestClient::new());
72 let app = Application::headless().with_http_client(http_client.clone());
73
74 app.run(move |cx| {
75 let app_state = init(cx);
76
77 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
78
79 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
80 registry.set_default_model(Some(model.clone()), cx);
81 });
82
83 let model_provider_id = model.provider_id();
84
85 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
86
87 cx.spawn(async move |cx| {
88 authenticate.await.unwrap();
89
90 std::fs::create_dir_all(REPOS_DIR)?;
91 std::fs::create_dir_all(WORKTREES_DIR)?;
92
93 let run_dir = Path::new(RUNS_DIR).join(format!(
94 "{}",
95 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
96 ));
97 std::fs::create_dir_all(&run_dir)?;
98
99 let mut examples = Vec::new();
100 for example_path in example_paths {
101 let example = Example::load_from_directory(&example_path, &run_dir)?;
102
103 if !example
104 .base
105 .language_extension
106 .as_ref()
107 .map_or(false, |lang| languages.contains(lang))
108 {
109 println!("Skipping {}", example.name);
110 continue;
111 }
112
113 println!("{}> Logging to {:?}", example.name, example.log_file_path);
114
115 examples.push(example);
116 }
117 let mut repo_urls = HashSet::new();
118
119 let mut clone_tasks = Vec::new();
120
121 for example in examples.iter() {
122 let repo_url = example.base.url.clone();
123 if repo_urls.insert(repo_url.clone()) {
124 let repo_path = repo_path_for_url(&repo_url);
125
126 if !repo_path.join(".git").is_dir() {
127 println!("Cloning: {}", repo_url);
128
129 let git_task = cx.spawn(async move |_cx| {
130 std::fs::create_dir_all(&repo_path)?;
131 run_git(&repo_path, &["init"]).await?;
132 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
133 });
134
135 clone_tasks.push(git_task);
136 } else {
137 println!("Already cloned: {}", repo_url);
138
139 let actual_origin =
140 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
141 if actual_origin != repo_url {
142 return Err(anyhow!(
143 "remote origin {} does not match expected origin {}",
144 actual_origin,
145 repo_url,
146 ));
147 }
148 }
149 }
150 }
151
152 future::join_all(clone_tasks).await;
153
154 for example in examples.iter() {
155 example.setup().await?;
156 }
157
158 let tasks = examples
159 .into_iter()
160 .map(|example| {
161 let app_state = app_state.clone();
162 let model = model.clone();
163 cx.spawn(async move |cx| {
164 (run_example(&example, model, app_state, cx).await, example)
165 })
166 })
167 .collect::<Vec<_>>();
168
169 let results: Vec<(Result<JudgeOutput>, Example)> = future::join_all(tasks).await;
170
171 println!("\n\n");
172 println!("========================================");
173 println!(" EVAL RESULTS ");
174 println!("========================================");
175 println!("");
176
177 let mut judge_scores = Vec::new();
178
179 for (result, example) in results {
180 println!("📜 {:<30}: {:?}", example.name, example.log_file_path);
181 match result {
182 Err(err) => {
183 println!("💥 {:<30}: {:?}", example.name, err);
184 }
185 Ok(judge_output) => {
186 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
187
188 println!(
189 "{} {:<30}: {}",
190 SCORES[judge_output.score.min(5) as usize],
191 example.name,
192 judge_output.score,
193 );
194 judge_scores.push(judge_output.score);
195 }
196 }
197 }
198
199 let score_count = judge_scores.len();
200 let average_score = judge_scores
201 .into_iter()
202 .map(|score| score as f32)
203 .sum::<f32>()
204 / (score_count as f32);
205 println!("\nAverage score: {average_score}");
206
207 cx.update(|cx| cx.quit())
208 })
209 .detach_and_log_err(cx);
210 });
211}
212
213async fn run_example(
214 example: &Example,
215 model: Arc<dyn LanguageModel>,
216 app_state: Arc<AgentAppState>,
217 cx: &mut AsyncApp,
218) -> Result<JudgeOutput> {
219 cx.update(|cx| example.run(model.clone(), app_state, cx))?
220 .await?;
221 let diff = example.repository_diff().await?;
222 example.judge(model, diff, cx).await
223}
224
225fn list_all_examples() -> Result<Vec<PathBuf>> {
226 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
227 let entries = std::fs::read_dir(path).unwrap();
228 let mut result_paths = Vec::new();
229 for entry in entries {
230 let entry = entry?;
231 let path = entry.path();
232 if path.is_dir() {
233 result_paths.push(path);
234 }
235 }
236 Ok(result_paths)
237}
238
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by [`init`] and shared behind an `Arc` with every example run.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()` and registers languages via `languages::init`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created with `Client::production` in `init`.
    pub client: Arc<Client>,
    /// User store entity built from `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem handle (a `RealFs` when built by `init`).
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options track the project `node` settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded from `fs`; the same instance is passed to `agent::init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
250
/// Initializes the globals the eval runner needs inside a headless GPUI app and returns the
/// shared handles that examples run against.
///
/// In order, this:
/// - seeds a `SettingsStore` from `settings::default_settings()`;
/// - installs an HTTP client;
/// - creates the production `Client`, a `RealFs`, the language registry, the user store and a
///   Node runtime;
/// - initializes the language, model, tool, context-server and agent subsystems;
/// - overrides `AssistantSettings` so that `always_allow_tool_actions` is `true`.
///
/// The statement order matters because later steps read globals installed by earlier ones.
///
/// # Panics
///
/// Panics if the default settings fail to load or the HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // NOTE(review): the app version is `SemanticVersion::default()`, which presumably is what
    // `AppVersion::global` reports in the User-Agent below — confirm this is intended.
    release_channel::init(SemanticVersion::default(), cx);
    // Provides the Tokio runtime entered below while building the reqwest client.
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Use the `proxy` setting when it parses as a URI; otherwise fall back to any proxy
    // configured in the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Enter the Tokio runtime while constructing the reqwest client (presumably required by
        // reqwest); the guard leaves it again at the end of this block.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // Replace the app-wide HTTP client with the one exposed by `client`.
    // NOTE(review): presumably `Client::production` wraps the proxy-aware client installed
    // above — confirm.
    cx.set_http_client(client.http_client().clone());

    // No explicit git binary is configured.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Keep the Node runtime's binary options in sync with the project `node` settings: whenever
    // the `SettingsStore` global changes, recompute the options and publish them on `tx`.
    // NOTE(review): `rx` starts out as `None` and only changes when this observer fires — confirm
    // `NodeRuntime` handles options arriving after construction.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    // Hard-coded: the eval runner does not treat stdout as a PTY.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Force `always_allow_tool_actions` on (presumably so the agent never waits for tool
    // confirmation during unattended evals), keeping every other assistant setting as loaded.
    AssistantSettings::override_global(
        AssistantSettings {
            always_allow_tool_actions: true,
            ..AssistantSettings::get_global(cx).clone()
        },
        cx,
    );

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
357
358pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
359 let model_registry = LanguageModelRegistry::read_global(cx);
360 let model = model_registry
361 .available_models(cx)
362 .find(|model| model.id().0 == model_name);
363
364 let Some(model) = model else {
365 return Err(anyhow!(
366 "No language model named {} was available. Available models: {}",
367 model_name,
368 model_registry
369 .available_models(cx)
370 .map(|model| model.id().0.clone())
371 .collect::<Vec<_>>()
372 .join(", ")
373 ));
374 };
375
376 Ok(model)
377}
378
379pub fn authenticate_model_provider(
380 provider_id: LanguageModelProviderId,
381 cx: &mut App,
382) -> Task<std::result::Result<(), AuthenticateError>> {
383 let model_registry = LanguageModelRegistry::read_global(cx);
384 let model_provider = model_registry.provider(&provider_id).unwrap();
385 model_provider.authenticate(cx)
386}