1mod example;
2
3use assistant_settings::AssistantSettings;
4use client::{Client, ProxySettings, UserStore};
5pub(crate) use example::*;
6
7use ::fs::RealFs;
8use anyhow::{Result, anyhow};
9use clap::Parser;
10use extension::ExtensionHostProxy;
11use futures::future;
12use gpui::http_client::{Uri, read_proxy_from_env};
13use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
14use gpui_tokio::Tokio;
15use language::LanguageRegistry;
16use language_model::{
17 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
18};
19use node_runtime::{NodeBinaryOptions, NodeRuntime};
20use project::Project;
21use project::project_settings::ProjectSettings;
22use prompt_store::PromptBuilder;
23use release_channel::AppVersion;
24use reqwest_client::ReqwestClient;
25use settings::{Settings, SettingsStore};
26use std::collections::HashSet;
27use std::path::{Path, PathBuf};
28use std::sync::Arc;
29use util::ResultExt as _;
30
/// Directory (relative to the current working directory) under which each eval run creates
/// its own timestamped subdirectory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
32
// Command-line arguments for the eval binary. clap's derive uses the `///` comments on the
// fields as `--help` text, so maintainer notes here are plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    // Each value is matched as a plain substring of an example directory's name.
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    // Name of the language model to evaluate with.
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    // Values are compared against each example's language extension (e.g. `rs`).
    /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
    #[arg(long, value_delimiter = ',')]
    languages: Option<Vec<String>>,
}
46
47fn main() {
48 env_logger::init();
49
50 let args = Args::parse();
51 let all_available_examples = list_all_examples().unwrap();
52 let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
53
54 let example_paths = all_available_examples
55 .iter()
56 .filter_map(|example_path| {
57 let name = example_path.file_name()?.to_string_lossy();
58 if args.examples.is_empty()
59 || args
60 .examples
61 .iter()
62 .any(|name_substring| name.contains(name_substring))
63 {
64 Some(example_path.clone())
65 } else {
66 None
67 }
68 })
69 .collect::<Vec<_>>();
70
71 let http_client = Arc::new(ReqwestClient::new());
72 let app = Application::headless().with_http_client(http_client.clone());
73
74 app.run(move |cx| {
75 let app_state = init(cx);
76
77 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
78
79 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
80 registry.set_default_model(Some(model.clone()), cx);
81 });
82
83 let model_provider_id = model.provider_id();
84
85 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
86
87 cx.spawn(async move |cx| {
88 authenticate.await.unwrap();
89
90 std::fs::create_dir_all(REPOS_DIR)?;
91 std::fs::create_dir_all(WORKTREES_DIR)?;
92
93 let run_dir = Path::new(RUNS_DIR).join(format!(
94 "{}",
95 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
96 ));
97 std::fs::create_dir_all(&run_dir)?;
98
99 let mut examples = Vec::new();
100 for example_path in example_paths {
101 let example = Example::load_from_directory(&example_path, &run_dir)?;
102
103 if !example
104 .base
105 .language_extension
106 .as_ref()
107 .map_or(false, |lang| languages.contains(lang))
108 {
109 println!("Skipping {}", example.name);
110 continue;
111 }
112
113 examples.push((example_path, example));
114 }
115 let mut repo_urls = HashSet::new();
116
117 let mut clone_tasks = Vec::new();
118
119 for (_, example) in examples.iter() {
120 let repo_url = example.base.url.clone();
121 if repo_urls.insert(repo_url.clone()) {
122 let repo_path = repo_path_for_url(&repo_url);
123
124 if !repo_path.join(".git").is_dir() {
125 println!("Cloning: {}", repo_url);
126
127 let git_task = cx.spawn(async move |_cx| {
128 std::fs::create_dir_all(&repo_path)?;
129 run_git(&repo_path, &["init"]).await?;
130 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
131 });
132
133 clone_tasks.push(git_task);
134 } else {
135 println!("Already cloned: {}", repo_url);
136
137 let actual_origin =
138 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
139 if actual_origin != repo_url {
140 return Err(anyhow!(
141 "remote origin {} does not match expected origin {}",
142 actual_origin,
143 repo_url,
144 ));
145 }
146 }
147 }
148 }
149
150 future::join_all(clone_tasks).await;
151
152 for (_, example) in examples.iter() {
153 example.setup().await?;
154 }
155
156 let tasks = examples
157 .into_iter()
158 .map(|(example_path, example)| {
159 let app_state = app_state.clone();
160 let model = model.clone();
161 cx.spawn(async move |cx| {
162 (
163 example_path,
164 run_example(example, model, app_state, cx).await,
165 )
166 })
167 })
168 .collect::<Vec<_>>();
169
170 let results: Vec<(PathBuf, Result<JudgeOutput>)> = future::join_all(tasks).await;
171
172 println!("\n\n");
173 println!("========================================");
174 println!(" EVAL RESULTS ");
175 println!("========================================");
176 println!("");
177
178 let mut judge_scores = Vec::new();
179
180 for (example_path, result) in results {
181 let example_name = example_path.file_name().unwrap().to_string_lossy();
182 match result {
183 Err(err) => {
184 println!("💥 {:<30}: {:?}", example_name, err);
185 }
186 Ok(judge_output) => {
187 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
188
189 println!(
190 "{} {:<30}: {}",
191 SCORES[judge_output.score.min(5) as usize],
192 example_name,
193 judge_output.score,
194 );
195 judge_scores.push(judge_output.score);
196 }
197 }
198 }
199
200 let score_count = judge_scores.len();
201 let average_score = judge_scores
202 .into_iter()
203 .map(|score| score as f32)
204 .sum::<f32>()
205 / (score_count as f32);
206 println!("\nAverage score: {average_score}");
207
208 cx.update(|cx| cx.quit())
209 })
210 .detach_and_log_err(cx);
211 });
212}
213
214async fn run_example(
215 mut example: Example,
216 model: Arc<dyn LanguageModel>,
217 app_state: Arc<AgentAppState>,
218 cx: &mut AsyncApp,
219) -> Result<JudgeOutput> {
220 cx.update(|cx| example.run(model.clone(), app_state, cx))?
221 .await?;
222 let diff = example.repository_diff().await?;
223 example.judge(model, diff, cx).await
224}
225
226fn list_all_examples() -> Result<Vec<PathBuf>> {
227 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
228 let entries = std::fs::read_dir(path).unwrap();
229 let mut result_paths = Vec::new();
230 for entry in entries {
231 let entry = entry?;
232 let path = entry.path();
233 if path.is_dir() {
234 result_paths.push(path);
235 }
236 }
237 Ok(result_paths)
238}
239
240/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
241pub struct AgentAppState {
242 pub languages: Arc<LanguageRegistry>,
243 pub client: Arc<Client>,
244 pub user_store: Entity<UserStore>,
245 pub fs: Arc<dyn fs::Fs>,
246 pub node_runtime: NodeRuntime,
247
248 // Additional fields not present in `workspace::AppState`.
249 pub prompt_builder: Arc<PromptBuilder>,
250}
251
252pub fn init(cx: &mut App) -> Arc<AgentAppState> {
253 release_channel::init(SemanticVersion::default(), cx);
254 gpui_tokio::init(cx);
255
256 let mut settings_store = SettingsStore::new(cx);
257 settings_store
258 .set_default_settings(settings::default_settings().as_ref(), cx)
259 .unwrap();
260 cx.set_global(settings_store);
261 client::init_settings(cx);
262
263 // Set User-Agent so we can download language servers from GitHub
264 let user_agent = format!(
265 "Zed/{} ({}; {})",
266 AppVersion::global(cx),
267 std::env::consts::OS,
268 std::env::consts::ARCH
269 );
270 let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
271 let proxy_url = proxy_str
272 .as_ref()
273 .and_then(|input| input.parse::<Uri>().ok())
274 .or_else(read_proxy_from_env);
275 let http = {
276 let _guard = Tokio::handle(cx).enter();
277
278 ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
279 .expect("could not start HTTP client")
280 };
281 cx.set_http_client(Arc::new(http));
282
283 Project::init_settings(cx);
284
285 let client = Client::production(cx);
286 cx.set_http_client(client.http_client().clone());
287
288 let git_binary_path = None;
289 let fs = Arc::new(RealFs::new(
290 git_binary_path,
291 cx.background_executor().clone(),
292 ));
293
294 let mut languages = LanguageRegistry::new(cx.background_executor().clone());
295 languages.set_language_server_download_dir(paths::languages_dir().clone());
296 let languages = Arc::new(languages);
297
298 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
299
300 extension::init(cx);
301
302 let (tx, rx) = async_watch::channel(None);
303 cx.observe_global::<SettingsStore>(move |cx| {
304 let settings = &ProjectSettings::get_global(cx).node;
305 let options = NodeBinaryOptions {
306 allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
307 allow_binary_download: true,
308 use_paths: settings.path.as_ref().map(|node_path| {
309 let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
310 let npm_path = settings
311 .npm_path
312 .as_ref()
313 .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
314 (
315 node_path.clone(),
316 npm_path.unwrap_or_else(|| {
317 let base_path = PathBuf::new();
318 node_path.parent().unwrap_or(&base_path).join("npm")
319 }),
320 )
321 }),
322 };
323 tx.send(Some(options)).log_err();
324 })
325 .detach();
326 let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
327
328 let extension_host_proxy = ExtensionHostProxy::global(cx);
329
330 language::init(cx);
331 language_extension::init(extension_host_proxy.clone(), languages.clone());
332 language_model::init(client.clone(), cx);
333 language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
334 languages::init(languages.clone(), node_runtime.clone(), cx);
335 assistant_tools::init(client.http_client().clone(), cx);
336 context_server::init(cx);
337 let stdout_is_a_pty = false;
338 let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
339 agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
340
341 AssistantSettings::override_global(
342 AssistantSettings {
343 always_allow_tool_actions: true,
344 ..AssistantSettings::get_global(cx).clone()
345 },
346 cx,
347 );
348
349 Arc::new(AgentAppState {
350 languages,
351 client,
352 user_store,
353 fs,
354 node_runtime,
355 prompt_builder,
356 })
357}
358
359pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
360 let model_registry = LanguageModelRegistry::read_global(cx);
361 let model = model_registry
362 .available_models(cx)
363 .find(|model| model.id().0 == model_name);
364
365 let Some(model) = model else {
366 return Err(anyhow!(
367 "No language model named {} was available. Available models: {}",
368 model_name,
369 model_registry
370 .available_models(cx)
371 .map(|model| model.id().0.clone())
372 .collect::<Vec<_>>()
373 .join(", ")
374 ));
375 };
376
377 Ok(model)
378}
379
380pub fn authenticate_model_provider(
381 provider_id: LanguageModelProviderId,
382 cx: &mut App,
383) -> Task<std::result::Result<(), AuthenticateError>> {
384 let model_registry = LanguageModelRegistry::read_global(cx);
385 let model_provider = model_registry.provider(&provider_id).unwrap();
386 model_provider.authenticate(cx)
387}