eval.rs

  1mod example;
  2
  3use client::{Client, ProxySettings, UserStore};
  4pub(crate) use example::*;
  5
  6use ::fs::RealFs;
  7use anyhow::{Result, anyhow};
  8use clap::Parser;
  9use extension::ExtensionHostProxy;
 10use futures::future;
 11use gpui::http_client::{Uri, read_proxy_from_env};
 12use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
 13use gpui_tokio::Tokio;
 14use language::LanguageRegistry;
 15use language_model::{
 16    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 17};
 18use node_runtime::{NodeBinaryOptions, NodeRuntime};
 19use project::Project;
 20use project::project_settings::ProjectSettings;
 21use prompt_store::PromptBuilder;
 22use release_channel::AppVersion;
 23use reqwest_client::ReqwestClient;
 24use settings::{Settings, SettingsStore};
 25use std::collections::HashSet;
 26use std::path::{Path, PathBuf};
 27use std::sync::Arc;
 28use std::usize;
 29use util::ResultExt as _;
 30
 31pub const RUNS_DIR: &str = "./crates/eval/runs";
 32
 33#[derive(Parser, Debug)]
 34#[command(name = "eval", disable_version_flag = true)]
 35struct Args {
 36    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 37    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 38    examples: Vec<String>,
 39    /// Model to use (default: "claude-3-7-sonnet-latest")
 40    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 41    model: String,
 42    /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
 43    #[arg(long, value_delimiter = ',')]
 44    languages: Option<Vec<String>>,
 45}
 46
 47fn main() {
 48    env_logger::init();
 49
 50    let args = Args::parse();
 51    let all_available_examples = list_all_examples().unwrap();
 52    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 53
 54    let example_paths = all_available_examples
 55        .iter()
 56        .filter_map(|example_path| {
 57            let name = example_path.file_name()?.to_string_lossy();
 58            if args.examples.is_empty()
 59                || args
 60                    .examples
 61                    .iter()
 62                    .any(|name_substring| name.contains(name_substring))
 63            {
 64                Some(example_path.clone())
 65            } else {
 66                None
 67            }
 68        })
 69        .collect::<Vec<_>>();
 70
 71    let http_client = Arc::new(ReqwestClient::new());
 72    let app = Application::headless().with_http_client(http_client.clone());
 73
 74    app.run(move |cx| {
 75        let app_state = init(cx);
 76
 77        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 78
 79        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 80            registry.set_default_model(Some(model.clone()), cx);
 81        });
 82
 83        let model_provider_id = model.provider_id();
 84
 85        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
 86
 87        cx.spawn(async move |cx| {
 88            authenticate.await.unwrap();
 89
 90            std::fs::create_dir_all(REPOS_DIR)?;
 91            std::fs::create_dir_all(WORKTREES_DIR)?;
 92
 93            let run_dir = Path::new(RUNS_DIR).join(format!(
 94                "{}",
 95                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
 96            ));
 97            std::fs::create_dir_all(&run_dir)?;
 98
 99            let mut examples = Vec::new();
100
101            const COLORS: [&str; 12] = [
102                "\x1b[31m", // Red
103                "\x1b[32m", // Green
104                "\x1b[33m", // Yellow
105                "\x1b[34m", // Blue
106                "\x1b[35m", // Magenta
107                "\x1b[36m", // Cyan
108                "\x1b[91m", // Bright Red
109                "\x1b[92m", // Bright Green
110                "\x1b[93m", // Bright Yellow
111                "\x1b[94m", // Bright Blue
112                "\x1b[95m", // Bright Magenta
113                "\x1b[96m", // Bright Cyan
114            ];
115
116            let mut max_name_width = 0;
117            let mut skipped = Vec::new();
118
119            for example_path in &example_paths {
120                let example = Example::load_from_directory(example_path, &run_dir)?;
121
122                if !example
123                    .base
124                    .language_extension
125                    .as_ref()
126                    .map_or(false, |lang| languages.contains(lang))
127                {
128                    skipped.push(example.name);
129                    continue;
130                }
131
132                let name_len = example.name.len();
133                if name_len > max_name_width {
134                    max_name_width = example.name.len();
135                }
136
137                examples.push(example);
138            }
139
140            println!("Skipped examples: {}\n", skipped.join(", "));
141
142            if examples.is_empty() {
143                eprintln!("Filter matched no examples");
144                return cx.update(|cx| cx.quit());
145            }
146
147            let mut repo_urls = HashSet::new();
148            let mut clone_tasks = Vec::new();
149
150            for (i, example) in examples.iter_mut().enumerate() {
151                let color = COLORS[i % COLORS.len()].to_string();
152                example.set_log_prefix_style(&color, max_name_width);
153
154                println!(
155                    "{}Logging to: {}",
156                    example.log_prefix,
157                    example.output_file_path.display()
158                );
159
160                let repo_url = example.base.url.clone();
161                if repo_urls.insert(repo_url.clone()) {
162                    let repo_path = repo_path_for_url(&repo_url);
163
164                    if !repo_path.join(".git").is_dir() {
165                        println!(
166                            "{:<width$}  < {}",
167                            "↓ Cloning",
168                            repo_url,
169                            width = max_name_width
170                        );
171
172                        let git_task = cx.spawn(async move |_cx| {
173                            std::fs::create_dir_all(&repo_path)?;
174                            run_git(&repo_path, &["init"]).await?;
175                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
176                        });
177
178                        clone_tasks.push(git_task);
179                    } else {
180                        println!(
181                            "{:<width$}  < {}",
182                            "✔︎ Already cloned",
183                            repo_url,
184                            width = max_name_width
185                        );
186
187                        let actual_origin =
188                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
189                        if actual_origin != repo_url {
190                            return Err(anyhow!(
191                                "remote origin {} does not match expected origin {}",
192                                actual_origin,
193                                repo_url,
194                            ));
195                        }
196                    }
197                }
198            }
199
200            future::join_all(clone_tasks).await;
201
202            for example in examples.iter_mut() {
203                example.setup().await?;
204            }
205
206            let tasks = examples
207                .into_iter()
208                .map(|example| {
209                    let app_state = app_state.clone();
210                    let model = model.clone();
211                    cx.spawn(async move |cx| {
212                        (run_example(&example, model, app_state, cx).await, example)
213                    })
214                })
215                .collect::<Vec<_>>();
216
217            let results: Vec<(Result<JudgeOutput>, Example)> = future::join_all(tasks).await;
218
219            println!("\n\n");
220            println!("========================================");
221            println!("              EVAL RESULTS              ");
222            println!("========================================");
223            println!("");
224
225            let mut judge_scores = Vec::new();
226
227            for (result, example) in results {
228                match result {
229                    Err(err) => {
230                        println!("💥 {}{:?}", example.log_prefix, err);
231                    }
232                    Ok(judge_output) => {
233                        const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
234
235                        println!(
236                            "{} {}{}",
237                            SCORES[judge_output.score.min(5) as usize],
238                            example.log_prefix,
239                            judge_output.score,
240                        );
241                        judge_scores.push(judge_output.score);
242                    }
243                }
244                println!(
245                    "{}    > {}",
246                    " ".repeat(max_name_width),
247                    example.output_file_path.display()
248                );
249            }
250
251            let score_count = judge_scores.len();
252            let average_score = judge_scores
253                .into_iter()
254                .map(|score| score as f32)
255                .sum::<f32>()
256                / (score_count as f32);
257            println!("\nAverage score: {average_score}");
258
259            cx.update(|cx| cx.quit())
260        })
261        .detach_and_log_err(cx);
262    });
263}
264
265async fn run_example(
266    example: &Example,
267    model: Arc<dyn LanguageModel>,
268    app_state: Arc<AgentAppState>,
269    cx: &mut AsyncApp,
270) -> Result<JudgeOutput> {
271    cx.update(|cx| example.run(model.clone(), app_state, cx))?
272        .await?;
273    let diff = example.repository_diff().await?;
274    example.judge(model, diff, cx).await
275}
276
277fn list_all_examples() -> Result<Vec<PathBuf>> {
278    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
279    let entries = std::fs::read_dir(path).unwrap();
280    let mut result_paths = Vec::new();
281    for entry in entries {
282        let entry = entry?;
283        let path = entry.path();
284        if path.is_dir() {
285            result_paths.push(path);
286        }
287    }
288    Ok(result_paths)
289}
290
291/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
292pub struct AgentAppState {
293    pub languages: Arc<LanguageRegistry>,
294    pub client: Arc<Client>,
295    pub user_store: Entity<UserStore>,
296    pub fs: Arc<dyn fs::Fs>,
297    pub node_runtime: NodeRuntime,
298
299    // Additional fields not present in `workspace::AppState`.
300    pub prompt_builder: Arc<PromptBuilder>,
301}
302
/// Initializes the global state needed to run agent examples in a headless `App`.
///
/// Returns the pieces that are later handed to each example run.
///
/// NOTE(review): statement order here appears significant. For example, the settings store is
/// installed before `client::init_settings` and before `ProxySettings` is read. Preserve the
/// order when editing.
///
/// # Panics
///
/// Panics if the bundled default settings or `../runner_settings.json` fail to apply, or if
/// the HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // Release channel is initialized with a default `SemanticVersion`.
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install a global settings store seeded with the bundled default settings.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings when it parses as a URI; otherwise fall back to the
    // proxy read from the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the client while the Tokio runtime handle is entered.
        // NOTE(review): presumably required because the reqwest client needs a Tokio context.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // Create the production client, then make its HTTP client the app-wide one.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    // No explicit git binary path is configured for `RealFs`.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the global `SettingsStore` is observed to change, recompute the Node binary
    // options from the project settings. The options are published to `NodeRuntime` through
    // this watch channel, which starts out holding `None`.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            // A configured node path is tilde-expanded. When no npm path is configured, `npm`
            // is looked up next to the node binary's parent directory.
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    // Passed to `PromptBuilder::load`; this runner never treats stdout as a PTY.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Load `../runner_settings.json` (embedded at compile time) as the user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
406
407pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
408    let model_registry = LanguageModelRegistry::read_global(cx);
409    let model = model_registry
410        .available_models(cx)
411        .find(|model| model.id().0 == model_name);
412
413    let Some(model) = model else {
414        return Err(anyhow!(
415            "No language model named {} was available. Available models: {}",
416            model_name,
417            model_registry
418                .available_models(cx)
419                .map(|model| model.id().0.clone())
420                .collect::<Vec<_>>()
421                .join(", ")
422        ));
423    };
424
425    Ok(model)
426}
427
428pub fn authenticate_model_provider(
429    provider_id: LanguageModelProviderId,
430    cx: &mut App,
431) -> Task<std::result::Result<(), AuthenticateError>> {
432    let model_registry = LanguageModelRegistry::read_global(cx);
433    let model_provider = model_registry.provider(&provider_id).unwrap();
434    model_provider.authenticate(cx)
435}