1mod example;
2
3use assistant_settings::AssistantSettings;
4use client::{Client, ProxySettings, UserStore};
5pub(crate) use example::*;
6
7use ::fs::RealFs;
8use anyhow::{Result, anyhow};
9use clap::Parser;
10use extension::ExtensionHostProxy;
11use futures::future;
12use gpui::http_client::{Uri, read_proxy_from_env};
13use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
14use gpui_tokio::Tokio;
15use language::LanguageRegistry;
16use language_model::{
17 AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
18};
19use node_runtime::{NodeBinaryOptions, NodeRuntime};
20use project::Project;
21use project::project_settings::ProjectSettings;
22use prompt_store::PromptBuilder;
23use release_channel::AppVersion;
24use reqwest_client::ReqwestClient;
25use settings::{Settings, SettingsStore};
26use std::collections::HashSet;
27use std::path::{Path, PathBuf};
28use std::sync::Arc;
29use util::ResultExt as _;
30
/// Directory under which each eval invocation creates its own timestamped run directory.
///
/// The path is relative, so it is resolved against the process's current working directory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
32
// Command-line interface of the eval runner, parsed with clap's derive API.
//
// NOTE: this is deliberately a plain `//` comment. clap's derive turns `///` doc comments
// into user-visible help text, so a doc comment on the struct would alter `--help` output.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    // Positional and repeatable: every bare argument becomes one substring filter.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    /// Model to use (default: "claude-3-7-sonnet-latest")
    // Passed as `--model <ID>`; the default mirrors the value named in the help text above.
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
}
43
44fn main() {
45 env_logger::init();
46
47 let args = Args::parse();
48 let all_available_examples = list_all_examples().unwrap();
49 let example_paths = all_available_examples
50 .iter()
51 .filter_map(|example_path| {
52 let name = example_path.file_name()?.to_string_lossy();
53 if args.examples.is_empty()
54 || args
55 .examples
56 .iter()
57 .any(|name_substring| name.contains(name_substring))
58 {
59 Some(example_path.clone())
60 } else {
61 None
62 }
63 })
64 .collect::<Vec<_>>();
65
66 let http_client = Arc::new(ReqwestClient::new());
67 let app = Application::headless().with_http_client(http_client.clone());
68
69 app.run(move |cx| {
70 let app_state = init(cx);
71
72 let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
73
74 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
75 registry.set_default_model(Some(model.clone()), cx);
76 });
77
78 let model_provider_id = model.provider_id();
79
80 let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
81
82 cx.spawn(async move |cx| {
83 authenticate.await.unwrap();
84
85 std::fs::create_dir_all(REPOS_DIR)?;
86 std::fs::create_dir_all(WORKTREES_DIR)?;
87
88 let run_dir = Path::new(RUNS_DIR).join(format!(
89 "{}",
90 chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
91 ));
92 std::fs::create_dir_all(&run_dir)?;
93
94 let mut examples = Vec::new();
95 for example_path in example_paths {
96 let example = Example::load_from_directory(&example_path, &run_dir)?;
97 examples.push((example_path, example));
98 }
99 let mut repo_urls = HashSet::new();
100
101 let mut clone_tasks = Vec::new();
102
103 for (_, example) in examples.iter() {
104 let repo_url = example.base.url.clone();
105 if repo_urls.insert(repo_url.clone()) {
106 let repo_path = repo_path_for_url(&repo_url);
107
108 if !repo_path.join(".git").is_dir() {
109 println!("Cloning: {}", repo_url);
110
111 let git_task = cx.spawn(async move |_cx| {
112 std::fs::create_dir_all(&repo_path)?;
113 run_git(&repo_path, &["init"]).await?;
114 run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
115 });
116
117 clone_tasks.push(git_task);
118 } else {
119 println!("Already cloned: {}", repo_url);
120
121 let actual_origin =
122 run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
123 if actual_origin != repo_url {
124 return Err(anyhow!(
125 "remote origin {} does not match expected origin {}",
126 actual_origin,
127 repo_url,
128 ));
129 }
130 }
131 }
132 }
133
134 future::join_all(clone_tasks).await;
135
136 let tasks = examples
137 .into_iter()
138 .map(|(example_path, example)| {
139 let app_state = app_state.clone();
140 let model = model.clone();
141 cx.spawn(async move |cx| {
142 (
143 example_path,
144 run_example(example, model, app_state, cx).await,
145 )
146 })
147 })
148 .collect::<Vec<_>>();
149
150 let results: Vec<(PathBuf, Result<JudgeOutput>)> = future::join_all(tasks).await;
151
152 println!("\n\n");
153 println!("========================================");
154 println!(" EVAL RESULTS ");
155 println!("========================================");
156 println!("");
157
158 let mut judge_scores = Vec::new();
159
160 for (example_path, result) in results {
161 let example_name = example_path.file_name().unwrap().to_string_lossy();
162 match result {
163 Err(err) => {
164 println!("💥 {:<30}: {:?}", example_name, err);
165 }
166 Ok(judge_output) => {
167 const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
168
169 println!(
170 "{} {:<30}: {}",
171 SCORES[judge_output.score.min(5) as usize],
172 example_name,
173 judge_output.score,
174 );
175 judge_scores.push(judge_output.score);
176 }
177 }
178 }
179
180 let score_count = judge_scores.len();
181 let average_score = judge_scores
182 .into_iter()
183 .map(|score| score as f32)
184 .sum::<f32>()
185 / (score_count as f32);
186 println!("\nAverage score: {average_score}");
187
188 cx.update(|cx| cx.quit())
189 })
190 .detach_and_log_err(cx);
191 });
192}
193
194async fn run_example(
195 mut example: Example,
196 model: Arc<dyn LanguageModel>,
197 app_state: Arc<AgentAppState>,
198 cx: &mut AsyncApp,
199) -> Result<JudgeOutput> {
200 example.setup().await?;
201 cx.update(|cx| example.run(model.clone(), app_state, cx))?
202 .await?;
203 let diff = example.repository_diff().await?;
204 example.judge(model, diff, cx).await
205}
206
207fn list_all_examples() -> Result<Vec<PathBuf>> {
208 let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
209 let entries = std::fs::read_dir(path).unwrap();
210 let mut result_paths = Vec::new();
211 for entry in entries {
212 let entry = entry?;
213 let path = entry.path();
214 if path.is_dir() {
215 result_paths.push(path);
216 }
217 }
218 Ok(result_paths)
219}
220
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// An instance is assembled by [`init`] and shared behind an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `init` through `Client::production`.
    pub client: Arc<Client>,
    /// User store entity that `init` builds on top of `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem handle; `init` supplies a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options `init` derives from the `node` project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder that `init` obtains from `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
232
/// Initializes the global application state required to run examples headlessly and returns
/// the handles the examples need, bundled in an [`AgentAppState`].
///
/// The statements below register globals and subsystems on `cx`; several later steps read
/// what earlier steps installed (settings, HTTP client, client), so their order matters.
///
/// # Panics
///
/// Panics if the default settings cannot be installed or if the HTTP client cannot be
/// started.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install a settings store holding only the built-in default settings.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from the settings; if it is absent or does not parse as a URI, fall
    // back to whatever `read_proxy_from_env` yields.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // NOTE(review): presumably the reqwest-based client must be constructed inside a
        // Tokio runtime context, hence the guard — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // Replace the HTTP client installed above with the one exposed by `client`.
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Node binary options travel to the node runtime over a watch channel that starts out
    // as `None`. The observer below recomputes them from the `node` project settings and
    // publishes them (presumably on every `SettingsStore` change — confirm).
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                // Expand a leading `~` in the configured paths.
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` in the node binary's directory.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        // A failed send is logged rather than propagated.
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    // Initialize the language, language-model, tool and agent subsystems.
    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Force `always_allow_tool_actions` on while keeping every other assistant setting
    // (presumably so tool calls never wait for interactive confirmation — confirm).
    AssistantSettings::override_global(
        AssistantSettings {
            always_allow_tool_actions: true,
            ..AssistantSettings::get_global(cx).clone()
        },
        cx,
    );

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
339
340pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
341 let model_registry = LanguageModelRegistry::read_global(cx);
342 let model = model_registry
343 .available_models(cx)
344 .find(|model| model.id().0 == model_name);
345
346 let Some(model) = model else {
347 return Err(anyhow!(
348 "No language model named {} was available. Available models: {}",
349 model_name,
350 model_registry
351 .available_models(cx)
352 .map(|model| model.id().0.clone())
353 .collect::<Vec<_>>()
354 .join(", ")
355 ));
356 };
357
358 Ok(model)
359}
360
361pub fn authenticate_model_provider(
362 provider_id: LanguageModelProviderId,
363 cx: &mut App,
364) -> Task<std::result::Result<(), AuthenticateError>> {
365 let model_registry = LanguageModelRegistry::read_global(cx);
366 let model_provider = model_registry.provider(&provider_id).unwrap();
367 model_provider.authenticate(cx)
368}