eval.rs

  1mod assertions;
  2mod example;
  3mod examples;
  4mod explorer;
  5mod ids;
  6mod instance;
  7mod tool_metrics;
  8
  9use assertions::display_error_row;
 10use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
 11pub(crate) use tool_metrics::*;
 12
 13use ::fs::RealFs;
 14use anyhow::anyhow;
 15use clap::Parser;
 16use client::{Client, ProxySettings, UserStore};
 17use collections::{HashMap, HashSet};
 18use extension::ExtensionHostProxy;
 19use futures::future;
 20use gpui::http_client::read_proxy_from_env;
 21use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 22use gpui_tokio::Tokio;
 23use language::LanguageRegistry;
 24use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 25use node_runtime::{NodeBinaryOptions, NodeRuntime};
 26use project::Project;
 27use project::project_settings::ProjectSettings;
 28use prompt_store::PromptBuilder;
 29use release_channel::AppVersion;
 30use reqwest_client::ReqwestClient;
 31use settings::{Settings, SettingsStore};
 32use std::cell::RefCell;
 33use std::collections::VecDeque;
 34use std::env;
 35use std::path::{Path, PathBuf};
 36use std::rc::Rc;
 37use std::sync::{Arc, LazyLock};
 38use util::ResultExt as _;
 39
 40static CARGO_MANIFEST_DIR: LazyLock<PathBuf> =
 41    LazyLock::new(|| PathBuf::from(env!("CARGO_MANIFEST_DIR")));
 42
 43#[derive(Parser, Debug)]
 44#[command(name = "eval", disable_version_flag = true)]
 45struct Args {
 46    /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
 47    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 48    filter: Vec<String>,
 49    /// Model to use (default: "claude-3-7-sonnet-latest")
 50    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 51    model: String,
 52    #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
 53    languages: Vec<String>,
 54    /// How many times to run each example.
 55    #[arg(long, default_value = "8")]
 56    repetitions: usize,
 57    /// Maximum number of examples to run concurrently.
 58    #[arg(long, default_value = "4")]
 59    concurrency: usize,
 60}
 61
 62fn main() {
 63    dotenv::from_filename(CARGO_MANIFEST_DIR.join(".env")).ok();
 64
 65    env_logger::init();
 66
 67    let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 68    let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 69    let session_id = uuid::Uuid::new_v4().to_string();
 70    let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
 71    let run_id = match env::var("GITHUB_RUN_ID") {
 72        Ok(run_id) => format!("github/{}", run_id),
 73        Err(_) => format!("local/{}", run_timestamp),
 74    };
 75
 76    let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
 77        .parent()
 78        .unwrap()
 79        .parent()
 80        .unwrap()
 81        .canonicalize()
 82        .unwrap();
 83    let eval_crate_dir = root_dir.join("crates").join("eval");
 84    let repos_dir = eval_crate_dir.join("repos");
 85    let worktrees_dir = eval_crate_dir.join("worktrees");
 86    let examples_dir = eval_crate_dir.join("src").join("examples");
 87    let run_dir = eval_crate_dir
 88        .join("runs")
 89        .join(format!("{}", run_timestamp));
 90    std::fs::create_dir_all(&run_dir).unwrap();
 91    std::fs::create_dir_all(&repos_dir).unwrap();
 92    std::fs::create_dir_all(&worktrees_dir).unwrap();
 93    std::fs::create_dir_all(&examples_dir).unwrap();
 94    std::fs::create_dir_all(&paths::config_dir()).unwrap();
 95
 96    let zed_commit_sha = commit_sha_for_path(&root_dir);
 97    let zed_branch_name = git_branch_for_path(&root_dir);
 98    let args = Args::parse();
 99    let languages: HashSet<String> = args.languages.into_iter().collect();
100
101    let http_client = Arc::new(ReqwestClient::new());
102    let app = Application::headless().with_http_client(http_client.clone());
103    let all_threads = examples::all(&examples_dir);
104
105    app.run(move |cx| {
106        let app_state = init(cx);
107
108        let telemetry = app_state.client.telemetry();
109        telemetry.start(system_id, installation_id, session_id, cx);
110
111        let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
112            && telemetry.has_checksum_seed();
113        if enable_telemetry {
114            println!("Telemetry enabled");
115            telemetry::event!(
116                "Agent Eval Started",
117                zed_commit_sha = zed_commit_sha,
118                zed_branch_name = zed_branch_name,
119                run_id = run_id,
120            );
121        }
122
123        let mut cumulative_tool_metrics = ToolMetrics::default();
124
125        let model_registry = LanguageModelRegistry::read_global(cx);
126        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
127        let model_provider_id = model.provider_id();
128        let model_provider = model_registry.provider(&model_provider_id).unwrap();
129
130        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
131            registry.set_default_model(
132                Some(ConfiguredModel {
133                    provider: model_provider.clone(),
134                    model: model.clone(),
135                }),
136                cx,
137            );
138        });
139
140        let authenticate_task = model_provider.authenticate(cx);
141
142        cx.spawn(async move |cx| {
143            authenticate_task.await.unwrap();
144
145            let mut examples = Vec::new();
146
147            const COLORS: [&str; 12] = [
148                "\x1b[31m", // Red
149                "\x1b[32m", // Green
150                "\x1b[33m", // Yellow
151                "\x1b[34m", // Blue
152                "\x1b[35m", // Magenta
153                "\x1b[36m", // Cyan
154                "\x1b[91m", // Bright Red
155                "\x1b[92m", // Bright Green
156                "\x1b[93m", // Bright Yellow
157                "\x1b[94m", // Bright Blue
158                "\x1b[95m", // Bright Magenta
159                "\x1b[96m", // Bright Cyan
160            ];
161
162            let mut skipped = Vec::new();
163
164            for thread in all_threads {
165                let meta = thread.meta();
166                if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
167                {
168                    skipped.push(meta.name);
169                    continue;
170                }
171
172                if let Some(language) = meta.language_server {
173                    if !languages.contains(&language.file_extension) {
174                        panic!(
175                            "Eval for {:?} could not be run because no language server was found for extension {:?}",
176                            meta.name,
177                            language.file_extension
178                        );
179                    }
180                }
181
182                // TODO: This creates a worktree per repetition. Ideally these examples should
183                // either be run sequentially on the same worktree, or reuse worktrees when there
184                // are more examples to run than the concurrency limit.
185                for repetition_number in 0..args.repetitions {
186                    let example_instance = ExampleInstance::new(
187                        thread.clone(),
188                        &repos_dir,
189                        &run_dir,
190                        &worktrees_dir,
191                        repetition_number,
192                    );
193
194                    examples.push(example_instance);
195                }
196            }
197
198            if !skipped.is_empty() {
199                println!("Skipped threads: {}", skipped.join(", "));
200            }
201
202            if examples.is_empty() {
203                eprintln!("Filter matched no examples");
204                return cx.update(|cx| cx.quit());
205            }
206
207            let mut repo_urls = HashSet::default();
208            let mut clone_tasks = Vec::new();
209
210            let max_name_width = examples
211                .iter()
212                .map(|e| e.worktree_name().len())
213                .max()
214                .unwrap_or(0);
215
216            for (i, example_instance) in examples.iter_mut().enumerate() {
217                let color = COLORS[i % COLORS.len()].to_string();
218                example_instance.set_log_prefix_style(&color, max_name_width);
219
220                println!(
221                    "{}Logging to: {}",
222                    example_instance.log_prefix,
223                    example_instance.run_directory.display()
224                );
225
226                let repo_url = example_instance.repo_url();
227                if repo_urls.insert(repo_url.clone()) {
228                    let repo_path = example_instance.repo_path.clone();
229
230                    if !repo_path.join(".git").is_dir() {
231                        println!(
232                            "{:<width$} < {}",
233                            "↓ Cloning",
234                            repo_url,
235                            width = max_name_width
236                        );
237
238                        let git_task = cx.spawn(async move |_cx| {
239                            std::fs::create_dir_all(&repo_path)?;
240                            run_git(&repo_path, &["init"]).await?;
241                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
242                        });
243
244                        clone_tasks.push(git_task);
245                    } else {
246                        println!(
247                            "{:<width$}  < {}",
248                            "✔︎ Already cloned",
249                            repo_url,
250                            width = max_name_width
251                        );
252
253                        let actual_origin =
254                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
255                        if actual_origin != repo_url {
256                            return Err(anyhow!(
257                                "remote origin {} does not match expected origin {}",
258                                actual_origin,
259                                repo_url,
260                            ));
261                        }
262                    }
263                }
264            }
265
266            future::join_all(clone_tasks).await;
267
268            for example_instance in examples.iter_mut() {
269                example_instance.fetch().await?;
270            }
271
272            let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
273            let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
274
275            future::join_all((0..args.concurrency).map(|_| {
276                let app_state = app_state.clone();
277                let model = model.clone();
278                let zed_commit_sha = zed_commit_sha.clone();
279                let zed_branch_name = zed_branch_name.clone();
280                let run_id = run_id.clone();
281                let examples = examples.clone();
282                let results = results_by_example_name.clone();
283                cx.spawn(async move |cx| {
284                    loop {
285                        let Some(mut example) = examples.borrow_mut().pop_front() else {
286                            break;
287                        };
288                        let result = async {
289                            example.setup().await?;
290                            let run_output = cx
291                                .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
292                                .await?;
293                            let judge_output = judge_example(
294                                example.clone(),
295                                model.clone(),
296                                &zed_commit_sha,
297                                &zed_branch_name,
298                                &run_id,
299                                &run_output,
300                                enable_telemetry,
301                                cx,
302                            )
303                            .await;
304                            anyhow::Ok((run_output, judge_output))
305                        }
306                        .await;
307                        results
308                            .borrow_mut()
309                            .entry(example.name.clone())
310                            .or_insert(Vec::new())
311                            .push((example.clone(), result));
312                    }
313                })
314            }))
315            .await;
316
317            print_report(
318                &mut results_by_example_name.borrow_mut(),
319                &mut cumulative_tool_metrics,
320                &run_dir,
321            )?;
322
323            app_state.client.telemetry().flush_events().await;
324
325            cx.update(|cx| cx.quit())
326        })
327        .detach_and_log_err(cx);
328    });
329}
330
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Constructed by `init`, which returns it wrapped in an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` sets its language server download directory to
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `init` via `Client::production`; also the source of the telemetry
    /// handle used in `main`.
    pub client: Arc<Client>,
    /// User store entity created from `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem handle; `init` supplies a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime created by `init` from the client's HTTP client and a watch channel of
    /// `NodeBinaryOptions` derived from the project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init` with `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
342
343pub fn init(cx: &mut App) -> Arc<AgentAppState> {
344    release_channel::init(SemanticVersion::default(), cx);
345    gpui_tokio::init(cx);
346
347    let mut settings_store = SettingsStore::new(cx);
348    settings_store
349        .set_default_settings(settings::default_settings().as_ref(), cx)
350        .unwrap();
351    cx.set_global(settings_store);
352    client::init_settings(cx);
353
354    // Set User-Agent so we can download language servers from GitHub
355    let user_agent = format!(
356        "Zed/{} ({}; {})",
357        AppVersion::global(cx),
358        std::env::consts::OS,
359        std::env::consts::ARCH
360    );
361    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
362    let proxy_url = proxy_str
363        .as_ref()
364        .and_then(|input| input.parse().ok())
365        .or_else(read_proxy_from_env);
366    let http = {
367        let _guard = Tokio::handle(cx).enter();
368
369        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
370            .expect("could not start HTTP client")
371    };
372    cx.set_http_client(Arc::new(http));
373
374    Project::init_settings(cx);
375
376    let client = Client::production(cx);
377    cx.set_http_client(client.http_client());
378
379    let git_binary_path = None;
380    let fs = Arc::new(RealFs::new(
381        git_binary_path,
382        cx.background_executor().clone(),
383    ));
384
385    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
386    languages.set_language_server_download_dir(paths::languages_dir().clone());
387    let languages = Arc::new(languages);
388
389    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
390
391    extension::init(cx);
392
393    let (tx, rx) = async_watch::channel(None);
394    cx.observe_global::<SettingsStore>(move |cx| {
395        let settings = &ProjectSettings::get_global(cx).node;
396        let options = NodeBinaryOptions {
397            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
398            allow_binary_download: true,
399            use_paths: settings.path.as_ref().map(|node_path| {
400                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
401                let npm_path = settings
402                    .npm_path
403                    .as_ref()
404                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
405                (
406                    node_path.clone(),
407                    npm_path.unwrap_or_else(|| {
408                        let base_path = PathBuf::new();
409                        node_path.parent().unwrap_or(&base_path).join("npm")
410                    }),
411                )
412            }),
413        };
414        tx.send(Some(options)).log_err();
415    })
416    .detach();
417    let node_runtime = NodeRuntime::new(client.http_client(), rx);
418
419    let extension_host_proxy = ExtensionHostProxy::global(cx);
420
421    language::init(cx);
422    language_extension::init(extension_host_proxy.clone(), languages.clone());
423    language_model::init(client.clone(), cx);
424    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
425    languages::init(languages.clone(), node_runtime.clone(), cx);
426    context_server::init(cx);
427    prompt_store::init(cx);
428    let stdout_is_a_pty = false;
429    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
430    agent::init(
431        fs.clone(),
432        client.clone(),
433        prompt_builder.clone(),
434        languages.clone(),
435        cx,
436    );
437    assistant_tools::init(client.http_client(), cx);
438
439    SettingsStore::update_global(cx, |store, cx| {
440        store.set_user_settings(include_str!("../runner_settings.json"), cx)
441    })
442    .unwrap();
443
444    Arc::new(AgentAppState {
445        languages,
446        client,
447        user_store,
448        fs,
449        node_runtime,
450        prompt_builder,
451    })
452}
453
454pub fn find_model(
455    model_name: &str,
456    model_registry: &LanguageModelRegistry,
457    cx: &App,
458) -> anyhow::Result<Arc<dyn LanguageModel>> {
459    let model = model_registry
460        .available_models(cx)
461        .find(|model| model.id().0 == model_name);
462
463    let Some(model) = model else {
464        return Err(anyhow!(
465            "No language model named {} was available. Available models: {}",
466            model_name,
467            model_registry
468                .available_models(cx)
469                .map(|model| model.id().0.clone())
470                .collect::<Vec<_>>()
471                .join(", ")
472        ));
473    };
474
475    Ok(model)
476}
477
478pub fn commit_sha_for_path(repo_path: &Path) -> String {
479    futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
480}
481
482pub fn git_branch_for_path(repo_path: &Path) -> String {
483    match std::env::var("GITHUB_REF_NAME") {
484        Ok(branch) => branch,
485        Err(_) => {
486            futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
487                .unwrap_or_else(|_| "unknown".to_string())
488        }
489    }
490}
491
492async fn judge_example(
493    example: ExampleInstance,
494    model: Arc<dyn LanguageModel>,
495    zed_commit_sha: &str,
496    zed_branch_name: &str,
497    run_id: &str,
498    run_output: &RunOutput,
499    enable_telemetry: bool,
500    cx: &AsyncApp,
501) -> JudgeOutput {
502    let judge_output = example.judge(model.clone(), &run_output, cx).await;
503
504    if enable_telemetry {
505        telemetry::event!(
506            "Agent Example Evaluated",
507            zed_commit_sha = zed_commit_sha,
508            zed_branch_name = zed_branch_name,
509            run_id = run_id,
510            example_name = example.name.clone(),
511            example_repetition = example.repetition,
512            diff_evaluation = judge_output.diff.clone(),
513            thread_evaluation = judge_output.thread.clone(),
514            tool_metrics = run_output.tool_metrics,
515            response_count = run_output.response_count,
516            token_usage = run_output.token_usage,
517            model = model.telemetry_id(),
518            model_provider = model.provider_id().to_string(),
519            repository_url = example.repo_url(),
520            repository_revision = example.revision(),
521            diagnostic_summary_before = run_output.diagnostic_summary_before,
522            diagnostic_summary_after = run_output.diagnostic_summary_after,
523            diagnostics_before = run_output.diagnostics_before,
524            diagnostics_after = run_output.diagnostics_after,
525        );
526    }
527
528    judge_output
529}
530
/// Width, in characters, of the rules and centered titles printed by the header helpers.
const HEADER_WIDTH: usize = 65;

/// Prints a top-level header to stdout: two blank lines, a rule of `=`, the centered title,
/// another rule of `=`, and a trailing blank line.
fn print_h1(header: &str) {
    let rule = "=".repeat(HEADER_WIDTH);
    println!("\n\n{rule}\n{header:^HEADER_WIDTH$}\n{rule}\n");
}

/// Prints a second-level header to stdout: one blank line, a rule of `-`, the centered title,
/// another rule of `-`, and a trailing blank line.
fn print_h2(header: &str) {
    let rule = "-".repeat(HEADER_WIDTH);
    println!("\n{rule}\n{header:^HEADER_WIDTH$}\n{rule}\n");
}
544
545fn print_report(
546    results_by_example_name: &mut HashMap<
547        String,
548        Vec<(ExampleInstance, anyhow::Result<(RunOutput, JudgeOutput)>)>,
549    >,
550    cumulative_tool_metrics: &mut ToolMetrics,
551    run_dir: &Path,
552) -> anyhow::Result<()> {
553    print_h1("EVAL RESULTS");
554
555    let mut diff_scores = Vec::new();
556    let mut thread_scores = Vec::new();
557    let mut programmatic_scores = Vec::new();
558    let mut error_count = 0;
559
560    for (example_name, results) in results_by_example_name.iter_mut() {
561        print_h2(example_name);
562
563        results.sort_unstable_by_key(|(example, _)| example.repetition);
564        let mut example_cumulative_tool_metrics = ToolMetrics::default();
565
566        let mut table_rows = String::new();
567
568        for (example, result) in results.iter() {
569            match result {
570                Err(err) => {
571                    display_error_row(&mut table_rows, example.repetition, err.to_string())?;
572                    error_count += 1;
573                }
574                Ok((run_output, judge_output)) => {
575                    cumulative_tool_metrics.merge(&run_output.tool_metrics);
576                    example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
577
578                    if !run_output.programmatic_assertions.total_count() > 0 {
579                        for assertion in &run_output.programmatic_assertions.ran {
580                            assertions::display_table_row(
581                                &mut table_rows,
582                                example.repetition,
583                                assertion,
584                            )?;
585                        }
586
587                        programmatic_scores
588                            .push(run_output.programmatic_assertions.passed_percentage())
589                    }
590
591                    if !judge_output.diff.is_empty() {
592                        diff_scores.push(judge_output.diff.passed_percentage());
593
594                        for assertion in &judge_output.diff.ran {
595                            assertions::display_table_row(
596                                &mut table_rows,
597                                example.repetition,
598                                assertion,
599                            )?;
600                        }
601                    }
602
603                    if !judge_output.thread.is_empty() {
604                        thread_scores.push(judge_output.thread.passed_percentage());
605
606                        for assertion in &judge_output.thread.ran {
607                            assertions::display_table_row(
608                                &mut table_rows,
609                                example.repetition,
610                                assertion,
611                            )?;
612                        }
613                    }
614                }
615            }
616        }
617
618        if !table_rows.is_empty() {
619            assertions::print_table_header();
620            print!("{}", table_rows);
621
622            assertions::print_table_divider();
623
624            for (example, result) in results.iter() {
625                if let Ok((run_output, judge_output)) = result {
626                    assertions::print_table_round_summary(
627                        &example.repetition.to_string(),
628                        [
629                            &run_output.programmatic_assertions,
630                            &judge_output.diff,
631                            &judge_output.thread,
632                        ]
633                        .into_iter(),
634                    )
635                }
636            }
637
638            assertions::print_table_divider();
639
640            assertions::print_table_round_summary(
641                "avg",
642                results.iter().flat_map(|(_, result)| {
643                    result.iter().flat_map(|(run_output, judge_output)| {
644                        [
645                            &run_output.programmatic_assertions,
646                            &judge_output.diff,
647                            &judge_output.thread,
648                        ]
649                        .into_iter()
650                    })
651                }),
652            );
653
654            assertions::print_table_footer();
655        }
656
657        if !example_cumulative_tool_metrics.is_empty() {
658            println!("{}", &example_cumulative_tool_metrics);
659        }
660    }
661
662    if results_by_example_name.len() > 1 {
663        print_h1("AGGREGATE");
664
665        if error_count > 0 {
666            println!("\n{error_count} examples failed to run!");
667        }
668
669        let programmatic_score_count = programmatic_scores.len();
670        if programmatic_score_count > 0 {
671            let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
672                / (programmatic_score_count as f32))
673                .floor();
674            println!("Average programmatic score: {average_programmatic_score}%");
675        }
676
677        let diff_score_count = diff_scores.len();
678        if diff_score_count > 0 {
679            let average_diff_score =
680                (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
681            println!("Average diff score: {average_diff_score}%");
682        }
683
684        let thread_score_count = thread_scores.len();
685
686        if thread_score_count > 0 {
687            let average_thread_score =
688                (thread_scores.into_iter().sum::<f32>() / (thread_score_count as f32)).floor();
689            println!("Average thread score: {average_thread_score}%");
690        }
691
692        println!("");
693
694        print_h2("CUMULATIVE TOOL METRICS");
695        println!("{}", cumulative_tool_metrics);
696    }
697
698    let explorer_output_path = run_dir.join("overview.html");
699    let mut json_paths: Vec<PathBuf> = results_by_example_name
700        .values()
701        .flat_map(|results| {
702            results.iter().map(|(example, _)| {
703                let absolute_path = example.run_directory.join("last.messages.json");
704                pathdiff::diff_paths(&absolute_path, run_dir)
705                    .unwrap_or_else(|| absolute_path.clone())
706            })
707        })
708        .collect::<Vec<_>>();
709    json_paths.sort();
710    if let Err(err) = explorer::generate_explorer_html(&json_paths, &explorer_output_path) {
711        eprintln!("Failed to generate explorer HTML: {}", err);
712    }
713
714    Ok(())
715}