eval.rs

  1mod assertions;
  2mod example;
  3mod examples;
  4mod explorer;
  5mod ids;
  6mod instance;
  7mod tool_metrics;
  8
  9use assertions::display_error_row;
 10use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
 11pub(crate) use tool_metrics::*;
 12
 13use ::fs::RealFs;
 14use anyhow::anyhow;
 15use clap::Parser;
 16use client::{Client, ProxySettings, UserStore};
 17use collections::{HashMap, HashSet};
 18use extension::ExtensionHostProxy;
 19use futures::future;
 20use gpui::http_client::read_proxy_from_env;
 21use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 22use gpui_tokio::Tokio;
 23use language::LanguageRegistry;
 24use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 25use node_runtime::{NodeBinaryOptions, NodeRuntime};
 26use project::Project;
 27use project::project_settings::ProjectSettings;
 28use prompt_store::PromptBuilder;
 29use release_channel::AppVersion;
 30use reqwest_client::ReqwestClient;
 31use settings::{Settings, SettingsStore};
 32use std::cell::RefCell;
 33use std::collections::VecDeque;
 34use std::env;
 35use std::path::{Path, PathBuf};
 36use std::rc::Rc;
 37use std::sync::{Arc, LazyLock};
 38use util::ResultExt as _;
 39
 40static CARGO_MANIFEST_DIR: LazyLock<PathBuf> =
 41    LazyLock::new(|| PathBuf::from(env!("CARGO_MANIFEST_DIR")));
 42
 43#[derive(Parser, Debug)]
 44#[command(name = "eval", disable_version_flag = true)]
 45struct Args {
 46    /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
 47    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 48    filter: Vec<String>,
 49    /// Model to use (default: "claude-3-7-sonnet-latest")
 50    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 51    model: String,
 52    #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
 53    languages: Vec<String>,
 54    /// How many times to run each example.
 55    #[arg(long, default_value = "1")]
 56    repetitions: usize,
 57    /// Maximum number of examples to run concurrently.
 58    #[arg(long, default_value = "10")]
 59    concurrency: usize,
 60}
 61
 62fn main() {
 63    dotenv::from_filename(CARGO_MANIFEST_DIR.join(".env")).ok();
 64
 65    env_logger::init();
 66
 67    let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 68    let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 69    let session_id = uuid::Uuid::new_v4().to_string();
 70    let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
 71    let run_id = match env::var("GITHUB_RUN_ID") {
 72        Ok(run_id) => format!("github/{}", run_id),
 73        Err(_) => format!("local/{}", run_timestamp),
 74    };
 75
 76    let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
 77        .parent()
 78        .unwrap()
 79        .parent()
 80        .unwrap()
 81        .canonicalize()
 82        .unwrap();
 83    let eval_crate_dir = root_dir.join("crates").join("eval");
 84    let repos_dir = eval_crate_dir.join("repos");
 85    let worktrees_dir = eval_crate_dir.join("worktrees");
 86    let examples_dir = eval_crate_dir.join("src").join("examples");
 87    let run_dir = eval_crate_dir
 88        .join("runs")
 89        .join(format!("{}", run_timestamp));
 90    std::fs::create_dir_all(&run_dir).unwrap();
 91    std::fs::create_dir_all(&repos_dir).unwrap();
 92    std::fs::create_dir_all(&worktrees_dir).unwrap();
 93    std::fs::create_dir_all(&examples_dir).unwrap();
 94    std::fs::create_dir_all(&paths::config_dir()).unwrap();
 95
 96    let zed_commit_sha = commit_sha_for_path(&root_dir);
 97    let zed_branch_name = git_branch_for_path(&root_dir);
 98    let args = Args::parse();
 99    let languages: HashSet<String> = args.languages.into_iter().collect();
100
101    let http_client = Arc::new(ReqwestClient::new());
102    let app = Application::headless().with_http_client(http_client.clone());
103    let all_threads = examples::all(&examples_dir);
104
105    app.run(move |cx| {
106        let app_state = init(cx);
107
108        let telemetry = app_state.client.telemetry();
109        telemetry.start(system_id, installation_id, session_id, cx);
110
111        let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
112            && telemetry.has_checksum_seed();
113        if enable_telemetry {
114            println!("Telemetry enabled");
115            telemetry::event!(
116                "Agent Eval Started",
117                zed_commit_sha = zed_commit_sha,
118                zed_branch_name = zed_branch_name,
119                run_id = run_id,
120            );
121        }
122
123        let mut cumulative_tool_metrics = ToolMetrics::default();
124
125        let model_registry = LanguageModelRegistry::read_global(cx);
126        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
127        let model_provider_id = model.provider_id();
128        let model_provider = model_registry.provider(&model_provider_id).unwrap();
129
130        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
131            registry.set_default_model(
132                Some(ConfiguredModel {
133                    provider: model_provider.clone(),
134                    model: model.clone(),
135                }),
136                cx,
137            );
138        });
139
140        let authenticate_task = model_provider.authenticate(cx);
141
142        cx.spawn(async move |cx| {
143            authenticate_task.await.unwrap();
144
145            let mut examples = Vec::new();
146
147            const COLORS: [&str; 12] = [
148                "\x1b[31m", // Red
149                "\x1b[32m", // Green
150                "\x1b[33m", // Yellow
151                "\x1b[34m", // Blue
152                "\x1b[35m", // Magenta
153                "\x1b[36m", // Cyan
154                "\x1b[91m", // Bright Red
155                "\x1b[92m", // Bright Green
156                "\x1b[93m", // Bright Yellow
157                "\x1b[94m", // Bright Blue
158                "\x1b[95m", // Bright Magenta
159                "\x1b[96m", // Bright Cyan
160            ];
161
162            let mut skipped = Vec::new();
163
164            for thread in all_threads {
165                let meta = thread.meta();
166                if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
167                {
168                    skipped.push(meta.name);
169                    continue;
170                }
171
172                if meta.language_server.map_or(false, |language| {
173                    !languages.contains(&language.file_extension)
174                }) {
175                    skipped.push(meta.name);
176                    continue;
177                }
178
179                // TODO: This creates a worktree per repetition. Ideally these examples should
180                // either be run sequentially on the same worktree, or reuse worktrees when there
181                // are more examples to run than the concurrency limit.
182                for repetition_number in 0..args.repetitions {
183                    let example_instance = ExampleInstance::new(
184                        thread.clone(),
185                        &repos_dir,
186                        &run_dir,
187                        &worktrees_dir,
188                        repetition_number,
189                    );
190
191                    examples.push(example_instance);
192                }
193            }
194
195            if !skipped.is_empty() {
196                println!("Skipped threads: {}", skipped.join(", "));
197            }
198
199            if examples.is_empty() {
200                eprintln!("Filter matched no examples");
201                return cx.update(|cx| cx.quit());
202            }
203
204            let mut repo_urls = HashSet::default();
205            let mut clone_tasks = Vec::new();
206
207            let max_name_width = examples
208                .iter()
209                .map(|e| e.worktree_name().len())
210                .max()
211                .unwrap_or(0);
212
213            for (i, example_instance) in examples.iter_mut().enumerate() {
214                let color = COLORS[i % COLORS.len()].to_string();
215                example_instance.set_log_prefix_style(&color, max_name_width);
216
217                println!(
218                    "{}Logging to: {}",
219                    example_instance.log_prefix,
220                    example_instance.run_directory.display()
221                );
222
223                let repo_url = example_instance.repo_url();
224                if repo_urls.insert(repo_url.clone()) {
225                    let repo_path = example_instance.repo_path.clone();
226
227                    if !repo_path.join(".git").is_dir() {
228                        println!(
229                            "{:<width$} < {}",
230                            "↓ Cloning",
231                            repo_url,
232                            width = max_name_width
233                        );
234
235                        let git_task = cx.spawn(async move |_cx| {
236                            std::fs::create_dir_all(&repo_path)?;
237                            run_git(&repo_path, &["init"]).await?;
238                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
239                        });
240
241                        clone_tasks.push(git_task);
242                    } else {
243                        println!(
244                            "{:<width$}  < {}",
245                            "✔︎ Already cloned",
246                            repo_url,
247                            width = max_name_width
248                        );
249
250                        let actual_origin =
251                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
252                        if actual_origin != repo_url {
253                            return Err(anyhow!(
254                                "remote origin {} does not match expected origin {}",
255                                actual_origin,
256                                repo_url,
257                            ));
258                        }
259                    }
260                }
261            }
262
263            future::join_all(clone_tasks).await;
264
265            for example_instance in examples.iter_mut() {
266                example_instance.fetch().await?;
267            }
268
269            let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
270            let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
271
272            future::join_all((0..args.concurrency).map(|_| {
273                let app_state = app_state.clone();
274                let model = model.clone();
275                let zed_commit_sha = zed_commit_sha.clone();
276                let zed_branch_name = zed_branch_name.clone();
277                let run_id = run_id.clone();
278                let examples = examples.clone();
279                let results = results_by_example_name.clone();
280                cx.spawn(async move |cx| {
281                    loop {
282                        let Some(mut example) = examples.borrow_mut().pop_front() else {
283                            break;
284                        };
285                        let result = async {
286                            example.setup().await?;
287                            let run_output = cx
288                                .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
289                                .await?;
290                            let judge_output = judge_example(
291                                example.clone(),
292                                model.clone(),
293                                &zed_commit_sha,
294                                &zed_branch_name,
295                                &run_id,
296                                &run_output,
297                                enable_telemetry,
298                                cx,
299                            )
300                            .await;
301                            anyhow::Ok((run_output, judge_output))
302                        }
303                        .await;
304                        results
305                            .borrow_mut()
306                            .entry(example.name.clone())
307                            .or_insert(Vec::new())
308                            .push((example.clone(), result));
309                    }
310                })
311            }))
312            .await;
313
314            print_report(
315                &mut results_by_example_name.borrow_mut(),
316                &mut cumulative_tool_metrics,
317                &run_dir,
318            )?;
319
320            app_state.client.telemetry().flush_events().await;
321
322            cx.update(|cx| cx.quit())
323        })
324        .detach_and_log_err(cx);
325    });
326}
327
328/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
329pub struct AgentAppState {
330    pub languages: Arc<LanguageRegistry>,
331    pub client: Arc<Client>,
332    pub user_store: Entity<UserStore>,
333    pub fs: Arc<dyn fs::Fs>,
334    pub node_runtime: NodeRuntime,
335
336    // Additional fields not present in `workspace::AppState`.
337    pub prompt_builder: Arc<PromptBuilder>,
338}
339
340pub fn init(cx: &mut App) -> Arc<AgentAppState> {
341    release_channel::init(SemanticVersion::default(), cx);
342    gpui_tokio::init(cx);
343
344    let mut settings_store = SettingsStore::new(cx);
345    settings_store
346        .set_default_settings(settings::default_settings().as_ref(), cx)
347        .unwrap();
348    cx.set_global(settings_store);
349    client::init_settings(cx);
350
351    // Set User-Agent so we can download language servers from GitHub
352    let user_agent = format!(
353        "Zed/{} ({}; {})",
354        AppVersion::global(cx),
355        std::env::consts::OS,
356        std::env::consts::ARCH
357    );
358    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
359    let proxy_url = proxy_str
360        .as_ref()
361        .and_then(|input| input.parse().ok())
362        .or_else(read_proxy_from_env);
363    let http = {
364        let _guard = Tokio::handle(cx).enter();
365
366        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
367            .expect("could not start HTTP client")
368    };
369    cx.set_http_client(Arc::new(http));
370
371    Project::init_settings(cx);
372
373    let client = Client::production(cx);
374    cx.set_http_client(client.http_client());
375
376    let git_binary_path = None;
377    let fs = Arc::new(RealFs::new(
378        git_binary_path,
379        cx.background_executor().clone(),
380    ));
381
382    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
383    languages.set_language_server_download_dir(paths::languages_dir().clone());
384    let languages = Arc::new(languages);
385
386    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
387
388    extension::init(cx);
389
390    let (tx, rx) = async_watch::channel(None);
391    cx.observe_global::<SettingsStore>(move |cx| {
392        let settings = &ProjectSettings::get_global(cx).node;
393        let options = NodeBinaryOptions {
394            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
395            allow_binary_download: true,
396            use_paths: settings.path.as_ref().map(|node_path| {
397                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
398                let npm_path = settings
399                    .npm_path
400                    .as_ref()
401                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
402                (
403                    node_path.clone(),
404                    npm_path.unwrap_or_else(|| {
405                        let base_path = PathBuf::new();
406                        node_path.parent().unwrap_or(&base_path).join("npm")
407                    }),
408                )
409            }),
410        };
411        tx.send(Some(options)).log_err();
412    })
413    .detach();
414    let node_runtime = NodeRuntime::new(client.http_client(), rx);
415
416    let extension_host_proxy = ExtensionHostProxy::global(cx);
417
418    language::init(cx);
419    language_extension::init(extension_host_proxy.clone(), languages.clone());
420    language_model::init(client.clone(), cx);
421    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
422    languages::init(languages.clone(), node_runtime.clone(), cx);
423    assistant_tools::init(client.http_client(), cx);
424    context_server::init(cx);
425    prompt_store::init(cx);
426    let stdout_is_a_pty = false;
427    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
428    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
429
430    SettingsStore::update_global(cx, |store, cx| {
431        store.set_user_settings(include_str!("../runner_settings.json"), cx)
432    })
433    .unwrap();
434
435    Arc::new(AgentAppState {
436        languages,
437        client,
438        user_store,
439        fs,
440        node_runtime,
441        prompt_builder,
442    })
443}
444
445pub fn find_model(
446    model_name: &str,
447    model_registry: &LanguageModelRegistry,
448    cx: &App,
449) -> anyhow::Result<Arc<dyn LanguageModel>> {
450    let model = model_registry
451        .available_models(cx)
452        .find(|model| model.id().0 == model_name);
453
454    let Some(model) = model else {
455        return Err(anyhow!(
456            "No language model named {} was available. Available models: {}",
457            model_name,
458            model_registry
459                .available_models(cx)
460                .map(|model| model.id().0.clone())
461                .collect::<Vec<_>>()
462                .join(", ")
463        ));
464    };
465
466    Ok(model)
467}
468
469pub fn commit_sha_for_path(repo_path: &Path) -> String {
470    futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
471}
472
473pub fn git_branch_for_path(repo_path: &Path) -> String {
474    match std::env::var("GITHUB_REF_NAME") {
475        Ok(branch) => branch,
476        Err(_) => {
477            futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
478                .unwrap_or_else(|_| "unknown".to_string())
479        }
480    }
481}
482
483async fn judge_example(
484    example: ExampleInstance,
485    model: Arc<dyn LanguageModel>,
486    zed_commit_sha: &str,
487    zed_branch_name: &str,
488    run_id: &str,
489    run_output: &RunOutput,
490    enable_telemetry: bool,
491    cx: &AsyncApp,
492) -> JudgeOutput {
493    let judge_output = example.judge(model.clone(), &run_output, cx).await;
494
495    if enable_telemetry {
496        telemetry::event!(
497            "Agent Example Evaluated",
498            zed_commit_sha = zed_commit_sha,
499            zed_branch_name = zed_branch_name,
500            run_id = run_id,
501            example_name = example.name.clone(),
502            example_repetition = example.repetition,
503            diff_evaluation = judge_output.diff.clone(),
504            thread_evaluation = judge_output.thread.clone(),
505            tool_metrics = run_output.tool_metrics,
506            response_count = run_output.response_count,
507            token_usage = run_output.token_usage,
508            model = model.telemetry_id(),
509            model_provider = model.provider_id().to_string(),
510            repository_url = example.repo_url(),
511            repository_revision = example.revision(),
512            diagnostic_summary_before = run_output.diagnostic_summary_before,
513            diagnostic_summary_after = run_output.diagnostic_summary_after,
514            diagnostics_before = run_output.diagnostics_before,
515            diagnostics_after = run_output.diagnostics_after,
516        );
517    }
518
519    judge_output
520}
521
/// Width, in characters, of the rules and centered titles printed by the header helpers.
const HEADER_WIDTH: usize = 65;

/// Prints `header` centered between two rules made of `rule_char`, with `leading` written
/// right before the first rule and an empty line after the second one.
fn print_banner(leading: &str, rule_char: &str, header: &str) {
    let rule = rule_char.repeat(HEADER_WIDTH);
    println!("{leading}{rule}");
    println!("{header:^HEADER_WIDTH$}");
    println!("{rule}\n");
}

/// Prints a top-level section header: two blank lines, then `header` framed by `=` rules.
fn print_h1(header: &str) {
    print_banner("\n\n", "=", header);
}

/// Prints a second-level section header: one blank line, then `header` framed by `-` rules.
fn print_h2(header: &str) {
    print_banner("\n", "-", header);
}
535
536fn print_report(
537    results_by_example_name: &mut HashMap<
538        String,
539        Vec<(ExampleInstance, anyhow::Result<(RunOutput, JudgeOutput)>)>,
540    >,
541    cumulative_tool_metrics: &mut ToolMetrics,
542    run_dir: &Path,
543) -> anyhow::Result<()> {
544    print_h1("EVAL RESULTS");
545
546    let mut diff_scores = Vec::new();
547    let mut thread_scores = Vec::new();
548    let mut programmatic_scores = Vec::new();
549    let mut error_count = 0;
550
551    for (example_name, results) in results_by_example_name.iter_mut() {
552        print_h2(example_name);
553
554        results.sort_unstable_by_key(|(example, _)| example.repetition);
555        let mut example_cumulative_tool_metrics = ToolMetrics::default();
556
557        let mut table_rows = String::new();
558
559        for (example, result) in results.iter() {
560            match result {
561                Err(err) => {
562                    display_error_row(&mut table_rows, example.repetition, err.to_string())?;
563                    error_count += 1;
564                }
565                Ok((run_output, judge_output)) => {
566                    cumulative_tool_metrics.merge(&run_output.tool_metrics);
567                    example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
568
569                    if !run_output.programmatic_assertions.total_count() > 0 {
570                        for assertion in &run_output.programmatic_assertions.ran {
571                            assertions::display_table_row(
572                                &mut table_rows,
573                                example.repetition,
574                                assertion,
575                            )?;
576                        }
577
578                        programmatic_scores
579                            .push(run_output.programmatic_assertions.passed_percentage())
580                    }
581
582                    if !judge_output.diff.is_empty() {
583                        diff_scores.push(judge_output.diff.passed_percentage());
584
585                        for assertion in &judge_output.diff.ran {
586                            assertions::display_table_row(
587                                &mut table_rows,
588                                example.repetition,
589                                assertion,
590                            )?;
591                        }
592                    }
593
594                    if !judge_output.thread.is_empty() {
595                        thread_scores.push(judge_output.thread.passed_percentage());
596
597                        for assertion in &judge_output.thread.ran {
598                            assertions::display_table_row(
599                                &mut table_rows,
600                                example.repetition,
601                                assertion,
602                            )?;
603                        }
604                    }
605                }
606            }
607        }
608
609        if !table_rows.is_empty() {
610            assertions::print_table_header();
611            print!("{}", table_rows);
612
613            assertions::print_table_divider();
614
615            for (example, result) in results.iter() {
616                if let Ok((run_output, judge_output)) = result {
617                    assertions::print_table_round_summary(
618                        &example.repetition.to_string(),
619                        [
620                            &run_output.programmatic_assertions,
621                            &judge_output.diff,
622                            &judge_output.thread,
623                        ]
624                        .into_iter(),
625                    )
626                }
627            }
628
629            assertions::print_table_divider();
630
631            assertions::print_table_round_summary(
632                "avg",
633                results.iter().flat_map(|(_, result)| {
634                    result.iter().flat_map(|(run_output, judge_output)| {
635                        [
636                            &run_output.programmatic_assertions,
637                            &judge_output.diff,
638                            &judge_output.thread,
639                        ]
640                        .into_iter()
641                    })
642                }),
643            );
644
645            assertions::print_table_footer();
646        }
647
648        if !example_cumulative_tool_metrics.is_empty() {
649            println!("{}", &example_cumulative_tool_metrics);
650        }
651    }
652
653    if results_by_example_name.len() > 1 {
654        print_h1("AGGREGATE");
655
656        if error_count > 0 {
657            println!("\n{error_count} examples failed to run!");
658        }
659
660        let programmatic_score_count = programmatic_scores.len();
661        if programmatic_score_count > 0 {
662            let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
663                / (programmatic_score_count as f32))
664                .floor();
665            println!("Average programmatic score: {average_programmatic_score}%");
666        }
667
668        let diff_score_count = diff_scores.len();
669        if diff_score_count > 0 {
670            let average_diff_score =
671                (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
672            println!("Average diff score: {average_diff_score}%");
673        }
674
675        let thread_score_count = thread_scores.len();
676
677        if thread_score_count > 0 {
678            let average_thread_score =
679                (thread_scores.into_iter().sum::<f32>() / (thread_score_count as f32)).floor();
680            println!("Average thread score: {average_thread_score}%");
681        }
682
683        println!("");
684
685        print_h2("CUMULATIVE TOOL METRICS");
686        println!("{}", cumulative_tool_metrics);
687    }
688
689    let explorer_output_path = run_dir.join("overview.html");
690    let mut json_paths: Vec<PathBuf> = results_by_example_name
691        .values()
692        .flat_map(|results| {
693            results.iter().map(|(example, _)| {
694                let absolute_path = example.run_directory.join("last.messages.json");
695                pathdiff::diff_paths(&absolute_path, run_dir)
696                    .unwrap_or_else(|| absolute_path.clone())
697            })
698        })
699        .collect::<Vec<_>>();
700    json_paths.sort();
701    if let Err(err) = explorer::generate_explorer_html(&json_paths, &explorer_output_path) {
702        eprintln!("Failed to generate explorer HTML: {}", err);
703    }
704
705    Ok(())
706}