eval.rs

  1mod assertions;
  2mod example;
  3mod examples;
  4mod explorer;
  5mod ids;
  6mod instance;
  7mod tool_metrics;
  8
  9use assertions::display_error_row;
 10use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
 11pub(crate) use tool_metrics::*;
 12
 13use ::fs::RealFs;
 14use anyhow::anyhow;
 15use clap::Parser;
 16use client::{Client, ProxySettings, UserStore};
 17use collections::{HashMap, HashSet};
 18use extension::ExtensionHostProxy;
 19use futures::future;
 20use gpui::http_client::{Uri, read_proxy_from_env};
 21use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 22use gpui_tokio::Tokio;
 23use language::LanguageRegistry;
 24use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 25use node_runtime::{NodeBinaryOptions, NodeRuntime};
 26use project::Project;
 27use project::project_settings::ProjectSettings;
 28use prompt_store::PromptBuilder;
 29use release_channel::AppVersion;
 30use reqwest_client::ReqwestClient;
 31use settings::{Settings, SettingsStore};
 32use std::cell::RefCell;
 33use std::collections::VecDeque;
 34use std::env;
 35use std::path::{Path, PathBuf};
 36use std::rc::Rc;
 37use std::sync::Arc;
 38use util::ResultExt as _;
 39
 40#[derive(Parser, Debug)]
 41#[command(name = "eval", disable_version_flag = true)]
 42struct Args {
 43    /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
 44    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 45    filter: Vec<String>,
 46    /// Model to use (default: "claude-3-7-sonnet-latest")
 47    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 48    model: String,
 49    #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
 50    languages: Vec<String>,
 51    /// How many times to run each example.
 52    #[arg(long, default_value = "1")]
 53    repetitions: usize,
 54    /// Maximum number of examples to run concurrently.
 55    #[arg(long, default_value = "10")]
 56    concurrency: usize,
 57}
 58
 59fn main() {
 60    env_logger::init();
 61
 62    let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 63    let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 64    let session_id = uuid::Uuid::new_v4().to_string();
 65    let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
 66    let run_id = match env::var("GITHUB_RUN_ID") {
 67        Ok(run_id) => format!("github/{}", run_id),
 68        Err(_) => format!("local/{}", run_timestamp),
 69    };
 70
 71    let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
 72        .parent()
 73        .unwrap()
 74        .parent()
 75        .unwrap()
 76        .canonicalize()
 77        .unwrap();
 78    let eval_crate_dir = root_dir.join("crates").join("eval");
 79    let repos_dir = eval_crate_dir.join("repos");
 80    let worktrees_dir = eval_crate_dir.join("worktrees");
 81    let examples_dir = eval_crate_dir.join("src").join("examples");
 82    let run_dir = eval_crate_dir
 83        .join("runs")
 84        .join(format!("{}", run_timestamp));
 85    std::fs::create_dir_all(&run_dir).unwrap();
 86    std::fs::create_dir_all(&repos_dir).unwrap();
 87    std::fs::create_dir_all(&worktrees_dir).unwrap();
 88    std::fs::create_dir_all(&examples_dir).unwrap();
 89    std::fs::create_dir_all(&paths::config_dir()).unwrap();
 90
 91    let zed_commit_sha = commit_sha_for_path(&root_dir);
 92    let zed_branch_name = git_branch_for_path(&root_dir);
 93    let args = Args::parse();
 94    let languages: HashSet<String> = args.languages.into_iter().collect();
 95
 96    let http_client = Arc::new(ReqwestClient::new());
 97    let app = Application::headless().with_http_client(http_client.clone());
 98    let all_threads = examples::all(&examples_dir);
 99
100    app.run(move |cx| {
101        let app_state = init(cx);
102
103        let telemetry = app_state.client.telemetry();
104        telemetry.start(system_id, installation_id, session_id, cx);
105
106        let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
107            && telemetry.has_checksum_seed();
108        if enable_telemetry {
109            println!("Telemetry enabled");
110            telemetry::event!(
111                "Agent Eval Started",
112                zed_commit_sha = zed_commit_sha,
113                zed_branch_name = zed_branch_name,
114                run_id = run_id,
115            );
116        }
117
118        let mut cumulative_tool_metrics = ToolMetrics::default();
119
120        let model_registry = LanguageModelRegistry::read_global(cx);
121        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
122        let model_provider_id = model.provider_id();
123        let model_provider = model_registry.provider(&model_provider_id).unwrap();
124
125        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
126            registry.set_default_model(
127                Some(ConfiguredModel {
128                    provider: model_provider.clone(),
129                    model: model.clone(),
130                }),
131                cx,
132            );
133        });
134
135        let authenticate_task = model_provider.authenticate(cx);
136
137        cx.spawn(async move |cx| {
138            authenticate_task.await.unwrap();
139
140            let mut examples = Vec::new();
141
142            const COLORS: [&str; 12] = [
143                "\x1b[31m", // Red
144                "\x1b[32m", // Green
145                "\x1b[33m", // Yellow
146                "\x1b[34m", // Blue
147                "\x1b[35m", // Magenta
148                "\x1b[36m", // Cyan
149                "\x1b[91m", // Bright Red
150                "\x1b[92m", // Bright Green
151                "\x1b[93m", // Bright Yellow
152                "\x1b[94m", // Bright Blue
153                "\x1b[95m", // Bright Magenta
154                "\x1b[96m", // Bright Cyan
155            ];
156
157            let mut skipped = Vec::new();
158
159            for thread in all_threads {
160                let meta = thread.meta();
161                if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
162                {
163                    skipped.push(meta.name);
164                    continue;
165                }
166
167                if meta.language_server.map_or(false, |language| {
168                    !languages.contains(&language.file_extension)
169                }) {
170                    skipped.push(meta.name);
171                    continue;
172                }
173
174                // TODO: This creates a worktree per repetition. Ideally these examples should
175                // either be run sequentially on the same worktree, or reuse worktrees when there
176                // are more examples to run than the concurrency limit.
177                for repetition_number in 0..args.repetitions {
178                    let example_instance = ExampleInstance::new(
179                        thread.clone(),
180                        &repos_dir,
181                        &run_dir,
182                        &worktrees_dir,
183                        repetition_number,
184                    );
185
186                    examples.push(example_instance);
187                }
188            }
189
190            if !skipped.is_empty() {
191                println!("Skipped threads: {}", skipped.join(", "));
192            }
193
194            if examples.is_empty() {
195                eprintln!("Filter matched no examples");
196                return cx.update(|cx| cx.quit());
197            }
198
199            let mut repo_urls = HashSet::default();
200            let mut clone_tasks = Vec::new();
201
202            let max_name_width = examples
203                .iter()
204                .map(|e| e.worktree_name().len())
205                .max()
206                .unwrap_or(0);
207
208            for (i, example_instance) in examples.iter_mut().enumerate() {
209                let color = COLORS[i % COLORS.len()].to_string();
210                example_instance.set_log_prefix_style(&color, max_name_width);
211
212                println!(
213                    "{}Logging to: {}",
214                    example_instance.log_prefix,
215                    example_instance.run_directory.display()
216                );
217
218                let repo_url = example_instance.repo_url();
219                if repo_urls.insert(repo_url.clone()) {
220                    let repo_path = example_instance.repo_path.clone();
221
222                    if !repo_path.join(".git").is_dir() {
223                        println!(
224                            "{:<width$} < {}",
225                            "↓ Cloning",
226                            repo_url,
227                            width = max_name_width
228                        );
229
230                        let git_task = cx.spawn(async move |_cx| {
231                            std::fs::create_dir_all(&repo_path)?;
232                            run_git(&repo_path, &["init"]).await?;
233                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
234                        });
235
236                        clone_tasks.push(git_task);
237                    } else {
238                        println!(
239                            "{:<width$}  < {}",
240                            "✔︎ Already cloned",
241                            repo_url,
242                            width = max_name_width
243                        );
244
245                        let actual_origin =
246                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
247                        if actual_origin != repo_url {
248                            return Err(anyhow!(
249                                "remote origin {} does not match expected origin {}",
250                                actual_origin,
251                                repo_url,
252                            ));
253                        }
254                    }
255                }
256            }
257
258            future::join_all(clone_tasks).await;
259
260            for example_instance in examples.iter_mut() {
261                example_instance.fetch().await?;
262            }
263
264            let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
265            let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
266
267            future::join_all((0..args.concurrency).map(|_| {
268                let app_state = app_state.clone();
269                let model = model.clone();
270                let zed_commit_sha = zed_commit_sha.clone();
271                let zed_branch_name = zed_branch_name.clone();
272                let run_id = run_id.clone();
273                let examples = examples.clone();
274                let results = results_by_example_name.clone();
275                cx.spawn(async move |cx| {
276                    loop {
277                        let Some(mut example) = examples.borrow_mut().pop_front() else {
278                            break;
279                        };
280                        let result = async {
281                            example.setup().await?;
282                            let run_output = cx
283                                .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
284                                .await?;
285                            let judge_output = judge_example(
286                                example.clone(),
287                                model.clone(),
288                                &zed_commit_sha,
289                                &zed_branch_name,
290                                &run_id,
291                                &run_output,
292                                enable_telemetry,
293                                cx,
294                            )
295                            .await;
296                            anyhow::Ok((run_output, judge_output))
297                        }
298                        .await;
299                        results
300                            .borrow_mut()
301                            .entry(example.name.clone())
302                            .or_insert(Vec::new())
303                            .push((example.clone(), result));
304                    }
305                })
306            }))
307            .await;
308
309            print_report(
310                &mut results_by_example_name.borrow_mut(),
311                &mut cumulative_tool_metrics,
312                &run_dir,
313            )?;
314
315            app_state.client.telemetry().flush_events().await;
316
317            cx.update(|cx| cx.quit())
318        })
319        .detach_and_log_err(cx);
320    });
321}
322
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by [`init`] and shared behind an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download
    /// directory at `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created with `Client::production`; also the source of the
    /// telemetry handle used by `main`.
    pub client: Arc<Client>,
    /// User store entity created from `client`.
    pub user_store: Entity<UserStore>,
    /// File system handle (`init` installs a `RealFs`).
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime configured from the `node` project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init` via `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
334
/// Initializes the globals and subsystems the eval needs on `cx` and returns
/// the resulting shared [`AgentAppState`].
///
/// The statements below register globals that later statements read back
/// (settings, HTTP client, extension host proxy, ...), so their order matters.
///
/// # Panics
///
/// Panics if the default settings or the bundled `runner_settings.json` cannot
/// be applied, or if the HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install a settings store holding only the default settings; the
    // eval-specific user settings are applied at the end of this function.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings; fall back to the environment when it is
    // unset or does not parse as a `Uri`.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // The client is built while the Tokio runtime handle is entered
        // (presumably required by the reqwest-based client; confirm).
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // From here on the app-wide HTTP client is the one owned by `client`.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // The `SettingsStore` observer derives the Node options from the `node`
    // project settings and sends them over the watch channel whose receiver is
    // handed to `NodeRuntime::new` below.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the
                    // configured node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the eval-specific user settings that are embedded into the binary
    // at compile time.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
439
440pub fn find_model(
441    model_name: &str,
442    model_registry: &LanguageModelRegistry,
443    cx: &App,
444) -> anyhow::Result<Arc<dyn LanguageModel>> {
445    let model = model_registry
446        .available_models(cx)
447        .find(|model| model.id().0 == model_name);
448
449    let Some(model) = model else {
450        return Err(anyhow!(
451            "No language model named {} was available. Available models: {}",
452            model_name,
453            model_registry
454                .available_models(cx)
455                .map(|model| model.id().0.clone())
456                .collect::<Vec<_>>()
457                .join(", ")
458        ));
459    };
460
461    Ok(model)
462}
463
464pub fn commit_sha_for_path(repo_path: &Path) -> String {
465    futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
466}
467
468pub fn git_branch_for_path(repo_path: &Path) -> String {
469    match std::env::var("GITHUB_REF_NAME") {
470        Ok(branch) => branch,
471        Err(_) => {
472            futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
473                .unwrap_or_else(|_| "unknown".to_string())
474        }
475    }
476}
477
478async fn judge_example(
479    example: ExampleInstance,
480    model: Arc<dyn LanguageModel>,
481    zed_commit_sha: &str,
482    zed_branch_name: &str,
483    run_id: &str,
484    run_output: &RunOutput,
485    enable_telemetry: bool,
486    cx: &AsyncApp,
487) -> JudgeOutput {
488    let judge_output = example.judge(model.clone(), &run_output, cx).await;
489
490    if enable_telemetry {
491        telemetry::event!(
492            "Agent Example Evaluated",
493            zed_commit_sha = zed_commit_sha,
494            zed_branch_name = zed_branch_name,
495            run_id = run_id,
496            example_name = example.name.clone(),
497            example_repetition = example.repetition,
498            diff_evaluation = judge_output.diff.clone(),
499            thread_evaluation = judge_output.thread.clone(),
500            tool_metrics = run_output.tool_metrics,
501            response_count = run_output.response_count,
502            token_usage = run_output.token_usage,
503            model = model.telemetry_id(),
504            model_provider = model.provider_id().to_string(),
505            repository_url = example.repo_url(),
506            repository_revision = example.revision(),
507            diagnostic_summary_before = run_output.diagnostic_summary_before,
508            diagnostic_summary_after = run_output.diagnostic_summary_after,
509            diagnostics_before = run_output.diagnostics_before,
510            diagnostics_after = run_output.diagnostics_after,
511        );
512    }
513
514    judge_output
515}
516
/// Width, in characters, of the banner lines printed by `print_h1` and `print_h2`.
const HEADER_WIDTH: usize = 65;
518
519fn print_h1(header: &str) {
520    println!("\n\n{:=^HEADER_WIDTH$}", "");
521    println!("{:^HEADER_WIDTH$}", header);
522    println!("{:=^HEADER_WIDTH$}\n", "");
523}
524
525fn print_h2(header: &str) {
526    println!("\n{:-^HEADER_WIDTH$}", "");
527    println!("{:^HEADER_WIDTH$}", header);
528    println!("{:-^HEADER_WIDTH$}\n", "");
529}
530
531fn print_report(
532    results_by_example_name: &mut HashMap<
533        String,
534        Vec<(ExampleInstance, anyhow::Result<(RunOutput, JudgeOutput)>)>,
535    >,
536    cumulative_tool_metrics: &mut ToolMetrics,
537    run_dir: &Path,
538) -> anyhow::Result<()> {
539    print_h1("EVAL RESULTS");
540
541    let mut diff_scores = Vec::new();
542    let mut thread_scores = Vec::new();
543    let mut programmatic_scores = Vec::new();
544    let mut error_count = 0;
545
546    for (example_name, results) in results_by_example_name.iter_mut() {
547        print_h2(example_name);
548
549        results.sort_unstable_by_key(|(example, _)| example.repetition);
550        let mut example_cumulative_tool_metrics = ToolMetrics::default();
551
552        let mut table_rows = String::new();
553
554        for (example, result) in results.iter() {
555            match result {
556                Err(err) => {
557                    display_error_row(&mut table_rows, example.repetition, err.to_string())?;
558                    error_count += 1;
559                }
560                Ok((run_output, judge_output)) => {
561                    cumulative_tool_metrics.merge(&run_output.tool_metrics);
562                    example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
563
564                    if !run_output.programmatic_assertions.total_count() > 0 {
565                        for assertion in &run_output.programmatic_assertions.ran {
566                            assertions::display_table_row(
567                                &mut table_rows,
568                                example.repetition,
569                                assertion,
570                            )?;
571                        }
572
573                        programmatic_scores
574                            .push(run_output.programmatic_assertions.passed_percentage())
575                    }
576
577                    if !judge_output.diff.is_empty() {
578                        diff_scores.push(judge_output.diff.passed_percentage());
579
580                        for assertion in &judge_output.diff.ran {
581                            assertions::display_table_row(
582                                &mut table_rows,
583                                example.repetition,
584                                assertion,
585                            )?;
586                        }
587                    }
588
589                    if !judge_output.thread.is_empty() {
590                        thread_scores.push(judge_output.thread.passed_percentage());
591
592                        for assertion in &judge_output.thread.ran {
593                            assertions::display_table_row(
594                                &mut table_rows,
595                                example.repetition,
596                                assertion,
597                            )?;
598                        }
599                    }
600                }
601            }
602        }
603
604        if !table_rows.is_empty() {
605            assertions::print_table_header();
606            print!("{}", table_rows);
607
608            assertions::print_table_divider();
609
610            for (example, result) in results.iter() {
611                if let Ok((run_output, judge_output)) = result {
612                    assertions::print_table_round_summary(
613                        &example.repetition.to_string(),
614                        [
615                            &run_output.programmatic_assertions,
616                            &judge_output.diff,
617                            &judge_output.thread,
618                        ]
619                        .into_iter(),
620                    )
621                }
622            }
623
624            assertions::print_table_divider();
625
626            assertions::print_table_round_summary(
627                "avg",
628                results.iter().flat_map(|(_, result)| {
629                    result.iter().flat_map(|(run_output, judge_output)| {
630                        [
631                            &run_output.programmatic_assertions,
632                            &judge_output.diff,
633                            &judge_output.thread,
634                        ]
635                        .into_iter()
636                    })
637                }),
638            );
639
640            assertions::print_table_footer();
641        }
642
643        if !example_cumulative_tool_metrics.is_empty() {
644            println!("{}", &example_cumulative_tool_metrics);
645        }
646    }
647
648    if results_by_example_name.len() > 1 {
649        print_h1("AGGREGATE");
650
651        if error_count > 0 {
652            println!("\n{error_count} examples failed to run!");
653        }
654
655        let programmatic_score_count = programmatic_scores.len();
656        if programmatic_score_count > 0 {
657            let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
658                / (programmatic_score_count as f32))
659                .floor();
660            println!("Average programmatic score: {average_programmatic_score}%");
661        }
662
663        let diff_score_count = diff_scores.len();
664        if diff_score_count > 0 {
665            let average_diff_score =
666                (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
667            println!("Average diff score: {average_diff_score}%");
668        }
669
670        let thread_score_count = thread_scores.len();
671
672        if thread_score_count > 0 {
673            let average_thread_score =
674                (thread_scores.into_iter().sum::<f32>() / (thread_score_count as f32)).floor();
675            println!("Average thread score: {average_thread_score}%");
676        }
677
678        println!("");
679
680        print_h2("CUMULATIVE TOOL METRICS");
681        println!("{}", cumulative_tool_metrics);
682    }
683
684    let explorer_output_path = run_dir.join("overview.html");
685    let mut json_paths: Vec<PathBuf> = results_by_example_name
686        .values()
687        .flat_map(|results| {
688            results.iter().map(|(example, _)| {
689                let absolute_path = example.run_directory.join("last.messages.json");
690                pathdiff::diff_paths(&absolute_path, run_dir)
691                    .unwrap_or_else(|| absolute_path.clone())
692            })
693        })
694        .collect::<Vec<_>>();
695    json_paths.sort();
696    if let Err(err) = explorer::generate_explorer_html(&json_paths, &explorer_output_path) {
697        eprintln!("Failed to generate explorer HTML: {}", err);
698    }
699
700    Ok(())
701}