eval.rs

  1mod assertions;
  2mod example;
  3mod examples;
  4mod ids;
  5mod instance;
  6mod tool_metrics;
  7
  8use assertions::display_error_row;
  9use instance::{ExampleInstance, JudgeOutput, RunOutput, run_git};
 10pub(crate) use tool_metrics::*;
 11
 12use ::fs::RealFs;
 13use anyhow::anyhow;
 14use clap::Parser;
 15use client::{Client, ProxySettings, UserStore};
 16use collections::{HashMap, HashSet};
 17use extension::ExtensionHostProxy;
 18use futures::future;
 19use gpui::http_client::{Uri, read_proxy_from_env};
 20use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 21use gpui_tokio::Tokio;
 22use language::LanguageRegistry;
 23use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 24use node_runtime::{NodeBinaryOptions, NodeRuntime};
 25use project::Project;
 26use project::project_settings::ProjectSettings;
 27use prompt_store::PromptBuilder;
 28use release_channel::AppVersion;
 29use reqwest_client::ReqwestClient;
 30use settings::{Settings, SettingsStore};
 31use std::cell::RefCell;
 32use std::collections::VecDeque;
 33use std::env;
 34use std::path::{Path, PathBuf};
 35use std::rc::Rc;
 36use std::sync::Arc;
 37use util::ResultExt as _;
 38
 39#[derive(Parser, Debug)]
 40#[command(name = "eval", disable_version_flag = true)]
 41struct Args {
 42    /// Runs all examples and threads that contain these substrings. If unspecified, all examples and threads are run.
 43    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 44    filter: Vec<String>,
 45    /// Model to use (default: "claude-3-7-sonnet-latest")
 46    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 47    model: String,
 48    #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
 49    languages: Vec<String>,
 50    /// How many times to run each example.
 51    #[arg(long, default_value = "1")]
 52    repetitions: usize,
 53    /// Maximum number of examples to run concurrently.
 54    #[arg(long, default_value = "10")]
 55    concurrency: usize,
 56}
 57
 58fn main() {
 59    env_logger::init();
 60
 61    let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 62    let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 63    let session_id = uuid::Uuid::new_v4().to_string();
 64    let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
 65    let run_id = match env::var("GITHUB_RUN_ID") {
 66        Ok(run_id) => format!("github/{}", run_id),
 67        Err(_) => format!("local/{}", run_timestamp),
 68    };
 69
 70    let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
 71        .parent()
 72        .unwrap()
 73        .parent()
 74        .unwrap()
 75        .canonicalize()
 76        .unwrap();
 77    let eval_crate_dir = root_dir.join("crates").join("eval");
 78    let repos_dir = eval_crate_dir.join("repos");
 79    let worktrees_dir = eval_crate_dir.join("worktrees");
 80    let examples_dir = eval_crate_dir.join("src").join("examples");
 81    let run_dir = eval_crate_dir
 82        .join("runs")
 83        .join(format!("{}", run_timestamp));
 84    std::fs::create_dir_all(&run_dir).unwrap();
 85    std::fs::create_dir_all(&repos_dir).unwrap();
 86    std::fs::create_dir_all(&worktrees_dir).unwrap();
 87    std::fs::create_dir_all(&examples_dir).unwrap();
 88    std::fs::create_dir_all(&paths::config_dir()).unwrap();
 89
 90    let zed_commit_sha = commit_sha_for_path(&root_dir);
 91    let zed_branch_name = git_branch_for_path(&root_dir);
 92    let args = Args::parse();
 93    let languages: HashSet<String> = args.languages.into_iter().collect();
 94
 95    let http_client = Arc::new(ReqwestClient::new());
 96    let app = Application::headless().with_http_client(http_client.clone());
 97    let all_threads = examples::all(&examples_dir);
 98
 99    app.run(move |cx| {
100        let app_state = init(cx);
101
102        let telemetry = app_state.client.telemetry();
103        telemetry.start(system_id, installation_id, session_id, cx);
104
105        let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
106            && telemetry.has_checksum_seed();
107        if enable_telemetry {
108            println!("Telemetry enabled");
109            telemetry::event!(
110                "Agent Eval Started",
111                zed_commit_sha = zed_commit_sha,
112                zed_branch_name = zed_branch_name,
113                run_id = run_id,
114            );
115        }
116
117        let mut cumulative_tool_metrics = ToolMetrics::default();
118
119        let model_registry = LanguageModelRegistry::read_global(cx);
120        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
121        let model_provider_id = model.provider_id();
122        let model_provider = model_registry.provider(&model_provider_id).unwrap();
123
124        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
125            registry.set_default_model(
126                Some(ConfiguredModel {
127                    provider: model_provider.clone(),
128                    model: model.clone(),
129                }),
130                cx,
131            );
132        });
133
134        let authenticate_task = model_provider.authenticate(cx);
135
136        cx.spawn(async move |cx| {
137            authenticate_task.await.unwrap();
138
139            let mut examples = Vec::new();
140
141            const COLORS: [&str; 12] = [
142                "\x1b[31m", // Red
143                "\x1b[32m", // Green
144                "\x1b[33m", // Yellow
145                "\x1b[34m", // Blue
146                "\x1b[35m", // Magenta
147                "\x1b[36m", // Cyan
148                "\x1b[91m", // Bright Red
149                "\x1b[92m", // Bright Green
150                "\x1b[93m", // Bright Yellow
151                "\x1b[94m", // Bright Blue
152                "\x1b[95m", // Bright Magenta
153                "\x1b[96m", // Bright Cyan
154            ];
155
156            let mut skipped = Vec::new();
157
158            for thread in all_threads {
159                let meta = thread.meta();
160                if !args.filter.is_empty() && !args.filter.iter().any(|sub| meta.name.contains(sub))
161                {
162                    skipped.push(meta.name);
163                    continue;
164                }
165
166                if meta.language_server.map_or(false, |language| {
167                    !languages.contains(&language.file_extension)
168                }) {
169                    skipped.push(meta.name);
170                    continue;
171                }
172
173                // TODO: This creates a worktree per repetition. Ideally these examples should
174                // either be run sequentially on the same worktree, or reuse worktrees when there
175                // are more examples to run than the concurrency limit.
176                for repetition_number in 0..args.repetitions {
177                    let example_instance = ExampleInstance::new(
178                        thread.clone(),
179                        &repos_dir,
180                        &run_dir,
181                        &worktrees_dir,
182                        repetition_number,
183                    );
184
185                    examples.push(example_instance);
186                }
187            }
188
189            if !skipped.is_empty() {
190                println!("Skipped threads: {}", skipped.join(", "));
191            }
192
193            if examples.is_empty() {
194                eprintln!("Filter matched no examples");
195                return cx.update(|cx| cx.quit());
196            }
197
198            let mut repo_urls = HashSet::default();
199            let mut clone_tasks = Vec::new();
200
201            let max_name_width = examples
202                .iter()
203                .map(|e| e.worktree_name().len())
204                .max()
205                .unwrap_or(0);
206
207            for (i, example_instance) in examples.iter_mut().enumerate() {
208                let color = COLORS[i % COLORS.len()].to_string();
209                example_instance.set_log_prefix_style(&color, max_name_width);
210
211                println!(
212                    "{}Logging to: {}",
213                    example_instance.log_prefix,
214                    example_instance.run_directory.display()
215                );
216
217                let repo_url = example_instance.repo_url();
218                if repo_urls.insert(repo_url.clone()) {
219                    let repo_path = example_instance.repo_path.clone();
220
221                    if !repo_path.join(".git").is_dir() {
222                        println!(
223                            "{:<width$} < {}",
224                            "↓ Cloning",
225                            repo_url,
226                            width = max_name_width
227                        );
228
229                        let git_task = cx.spawn(async move |_cx| {
230                            std::fs::create_dir_all(&repo_path)?;
231                            run_git(&repo_path, &["init"]).await?;
232                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
233                        });
234
235                        clone_tasks.push(git_task);
236                    } else {
237                        println!(
238                            "{:<width$}  < {}",
239                            "✔︎ Already cloned",
240                            repo_url,
241                            width = max_name_width
242                        );
243
244                        let actual_origin =
245                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
246                        if actual_origin != repo_url {
247                            return Err(anyhow!(
248                                "remote origin {} does not match expected origin {}",
249                                actual_origin,
250                                repo_url,
251                            ));
252                        }
253                    }
254                }
255            }
256
257            future::join_all(clone_tasks).await;
258
259            for example_instance in examples.iter_mut() {
260                example_instance.fetch().await?;
261            }
262
263            let examples = Rc::new(RefCell::new(VecDeque::from(examples)));
264            let results_by_example_name = Rc::new(RefCell::new(HashMap::default()));
265
266            future::join_all((0..args.concurrency).map(|_| {
267                let app_state = app_state.clone();
268                let model = model.clone();
269                let zed_commit_sha = zed_commit_sha.clone();
270                let zed_branch_name = zed_branch_name.clone();
271                let run_id = run_id.clone();
272                let examples = examples.clone();
273                let results = results_by_example_name.clone();
274                cx.spawn(async move |cx| {
275                    loop {
276                        let Some(mut example) = examples.borrow_mut().pop_front() else {
277                            break;
278                        };
279                        let result = async {
280                            example.setup().await?;
281                            let run_output = cx
282                                .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
283                                .await?;
284                            let judge_output = judge_example(
285                                example.clone(),
286                                model.clone(),
287                                &zed_commit_sha,
288                                &zed_branch_name,
289                                &run_id,
290                                &run_output,
291                                enable_telemetry,
292                                cx,
293                            )
294                            .await;
295                            anyhow::Ok((run_output, judge_output))
296                        }
297                        .await;
298                        results
299                            .borrow_mut()
300                            .entry(example.name.clone())
301                            .or_insert(Vec::new())
302                            .push((example.clone(), result));
303                    }
304                })
305            }))
306            .await;
307
308            print_h1("EVAL RESULTS");
309
310            let mut diff_scores = Vec::new();
311            let mut thread_scores = Vec::new();
312            let mut programmatic_scores = Vec::new();
313            let mut error_count = 0;
314
315            for (example_name, results) in results_by_example_name.borrow_mut().iter_mut() {
316                print_h2(&example_name);
317
318                results.sort_unstable_by_key(|(example, _)| example.repetition);
319                let mut example_cumulative_tool_metrics = ToolMetrics::default();
320
321                let mut table_rows = String::new();
322
323                for (example, result) in results.iter() {
324                    match result {
325                        Err(err) => {
326                            display_error_row(
327                                &mut table_rows,
328                                example.repetition,
329                                err.to_string(),
330                            )?;
331                            error_count += 1;
332                        }
333                        Ok((run_output, judge_output)) => {
334                            cumulative_tool_metrics.merge(&run_output.tool_metrics);
335                            example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
336
337                            if !run_output.programmatic_assertions.total_count() > 0 {
338                                for assertion in &run_output.programmatic_assertions.ran {
339                                    assertions::display_table_row(
340                                        &mut table_rows,
341                                        example.repetition,
342                                        assertion,
343                                    )?;
344                                }
345
346                                programmatic_scores
347                                    .push(run_output.programmatic_assertions.passed_percentage())
348                            }
349
350                            if !judge_output.diff.is_empty() {
351                                diff_scores.push(judge_output.diff.passed_percentage());
352
353                                for assertion in &judge_output.diff.ran {
354                                    assertions::display_table_row(
355                                        &mut table_rows,
356                                        example.repetition,
357                                        assertion,
358                                    )?;
359                                }
360                            }
361
362                            if !judge_output.thread.is_empty() {
363                                thread_scores.push(judge_output.thread.passed_percentage());
364
365                                for assertion in &judge_output.thread.ran {
366                                    assertions::display_table_row(
367                                        &mut table_rows,
368                                        example.repetition,
369                                        assertion,
370                                    )?;
371                                }
372                            }
373                        }
374                    }
375                }
376
377                if !table_rows.is_empty() {
378                    assertions::print_table_header();
379                    print!("{}", table_rows);
380
381                    assertions::print_table_divider();
382
383                    for (example, result) in results.iter() {
384                        if let Ok((run_output, judge_output)) = result {
385                            assertions::print_table_round_summary(
386                                &example.repetition.to_string(),
387                                [
388                                    &run_output.programmatic_assertions,
389                                    &judge_output.diff,
390                                    &judge_output.thread,
391                                ]
392                                .into_iter(),
393                            )
394                        }
395                    }
396
397                    assertions::print_table_divider();
398
399                    assertions::print_table_round_summary(
400                        "avg",
401                        results.iter().flat_map(|(_, result)| {
402                            result.iter().flat_map(|(run_output, judge_output)| {
403                                [
404                                    &run_output.programmatic_assertions,
405                                    &judge_output.diff,
406                                    &judge_output.thread,
407                                ]
408                                .into_iter()
409                            })
410                        }),
411                    );
412
413                    assertions::print_table_footer();
414                }
415
416                if !example_cumulative_tool_metrics.is_empty() {
417                    println!("{}", &example_cumulative_tool_metrics);
418                }
419            }
420
421            if results_by_example_name.borrow().len() > 1 {
422                print_h1("AGGREGATE");
423
424                if error_count > 0 {
425                    println!("\n{error_count} examples failed to run!");
426                }
427
428                let programmatic_score_count = programmatic_scores.len();
429                if programmatic_score_count > 0 {
430                    let average_programmatic_score = (programmatic_scores.into_iter().sum::<f32>()
431                        / (programmatic_score_count as f32))
432                        .floor();
433                    println!("Average programmatic score: {average_programmatic_score}%");
434                }
435
436                let diff_score_count = diff_scores.len();
437                if diff_score_count > 0 {
438                    let average_diff_score =
439                        (diff_scores.into_iter().sum::<f32>() / (diff_score_count as f32)).floor();
440                    println!("Average diff score: {average_diff_score}%");
441                }
442
443                let thread_score_count = thread_scores.len();
444
445                if thread_score_count > 0 {
446                    let average_thread_score = (thread_scores.into_iter().sum::<f32>()
447                        / (thread_score_count as f32))
448                        .floor();
449                    println!("Average thread score: {average_thread_score}%");
450                }
451
452                println!("");
453
454                print_h2("CUMULATIVE TOOL METRICS");
455                println!("{}", cumulative_tool_metrics);
456            }
457
458            app_state.client.telemetry().flush_events().await;
459
460            cx.update(|cx| cx.quit())
461        })
462        .detach_and_log_err(cx);
463    });
464}
465
466/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
467pub struct AgentAppState {
468    pub languages: Arc<LanguageRegistry>,
469    pub client: Arc<Client>,
470    pub user_store: Entity<UserStore>,
471    pub fs: Arc<dyn fs::Fs>,
472    pub node_runtime: NodeRuntime,
473
474    // Additional fields not present in `workspace::AppState`.
475    pub prompt_builder: Arc<PromptBuilder>,
476}
477
478pub fn init(cx: &mut App) -> Arc<AgentAppState> {
479    release_channel::init(SemanticVersion::default(), cx);
480    gpui_tokio::init(cx);
481
482    let mut settings_store = SettingsStore::new(cx);
483    settings_store
484        .set_default_settings(settings::default_settings().as_ref(), cx)
485        .unwrap();
486    cx.set_global(settings_store);
487    client::init_settings(cx);
488
489    // Set User-Agent so we can download language servers from GitHub
490    let user_agent = format!(
491        "Zed/{} ({}; {})",
492        AppVersion::global(cx),
493        std::env::consts::OS,
494        std::env::consts::ARCH
495    );
496    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
497    let proxy_url = proxy_str
498        .as_ref()
499        .and_then(|input| input.parse::<Uri>().ok())
500        .or_else(read_proxy_from_env);
501    let http = {
502        let _guard = Tokio::handle(cx).enter();
503
504        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
505            .expect("could not start HTTP client")
506    };
507    cx.set_http_client(Arc::new(http));
508
509    Project::init_settings(cx);
510
511    let client = Client::production(cx);
512    cx.set_http_client(client.http_client().clone());
513
514    let git_binary_path = None;
515    let fs = Arc::new(RealFs::new(
516        git_binary_path,
517        cx.background_executor().clone(),
518    ));
519
520    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
521    languages.set_language_server_download_dir(paths::languages_dir().clone());
522    let languages = Arc::new(languages);
523
524    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
525
526    extension::init(cx);
527
528    let (tx, rx) = async_watch::channel(None);
529    cx.observe_global::<SettingsStore>(move |cx| {
530        let settings = &ProjectSettings::get_global(cx).node;
531        let options = NodeBinaryOptions {
532            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
533            allow_binary_download: true,
534            use_paths: settings.path.as_ref().map(|node_path| {
535                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
536                let npm_path = settings
537                    .npm_path
538                    .as_ref()
539                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
540                (
541                    node_path.clone(),
542                    npm_path.unwrap_or_else(|| {
543                        let base_path = PathBuf::new();
544                        node_path.parent().unwrap_or(&base_path).join("npm")
545                    }),
546                )
547            }),
548        };
549        tx.send(Some(options)).log_err();
550    })
551    .detach();
552    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
553
554    let extension_host_proxy = ExtensionHostProxy::global(cx);
555
556    language::init(cx);
557    language_extension::init(extension_host_proxy.clone(), languages.clone());
558    language_model::init(client.clone(), cx);
559    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
560    languages::init(languages.clone(), node_runtime.clone(), cx);
561    assistant_tools::init(client.http_client().clone(), cx);
562    context_server::init(cx);
563    prompt_store::init(cx);
564    let stdout_is_a_pty = false;
565    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
566    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
567
568    SettingsStore::update_global(cx, |store, cx| {
569        store.set_user_settings(include_str!("../runner_settings.json"), cx)
570    })
571    .unwrap();
572
573    Arc::new(AgentAppState {
574        languages,
575        client,
576        user_store,
577        fs,
578        node_runtime,
579        prompt_builder,
580    })
581}
582
583pub fn find_model(
584    model_name: &str,
585    model_registry: &LanguageModelRegistry,
586    cx: &App,
587) -> anyhow::Result<Arc<dyn LanguageModel>> {
588    let model = model_registry
589        .available_models(cx)
590        .find(|model| model.id().0 == model_name);
591
592    let Some(model) = model else {
593        return Err(anyhow!(
594            "No language model named {} was available. Available models: {}",
595            model_name,
596            model_registry
597                .available_models(cx)
598                .map(|model| model.id().0.clone())
599                .collect::<Vec<_>>()
600                .join(", ")
601        ));
602    };
603
604    Ok(model)
605}
606
607pub fn commit_sha_for_path(repo_path: &Path) -> String {
608    futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
609}
610
611pub fn git_branch_for_path(repo_path: &Path) -> String {
612    match std::env::var("GITHUB_REF_NAME") {
613        Ok(branch) => branch,
614        Err(_) => {
615            futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
616                .unwrap_or_else(|_| "unknown".to_string())
617        }
618    }
619}
620
621async fn judge_example(
622    example: ExampleInstance,
623    model: Arc<dyn LanguageModel>,
624    zed_commit_sha: &str,
625    zed_branch_name: &str,
626    run_id: &str,
627    run_output: &RunOutput,
628    enable_telemetry: bool,
629    cx: &AsyncApp,
630) -> JudgeOutput {
631    let judge_output = example.judge(model.clone(), &run_output, cx).await;
632
633    if enable_telemetry {
634        telemetry::event!(
635            "Agent Example Evaluated",
636            zed_commit_sha = zed_commit_sha,
637            zed_branch_name = zed_branch_name,
638            run_id = run_id,
639            example_name = example.name.clone(),
640            example_repetition = example.repetition,
641            diff_evaluation = judge_output.diff.clone(),
642            thread_evaluation = judge_output.thread.clone(),
643            tool_metrics = run_output.tool_metrics,
644            response_count = run_output.response_count,
645            token_usage = run_output.token_usage,
646            model = model.telemetry_id(),
647            model_provider = model.provider_id().to_string(),
648            repository_url = example.repo_url(),
649            repository_revision = example.revision(),
650            diagnostic_summary_before = run_output.diagnostic_summary_before,
651            diagnostic_summary_after = run_output.diagnostic_summary_after,
652            diagnostics_before = run_output.diagnostics_before,
653            diagnostics_after = run_output.diagnostics_after,
654        );
655    }
656
657    judge_output
658}
659
/// Width, in characters, of the rules and centered titles printed by the header helpers.
const HEADER_WIDTH: usize = 65;

/// Prints a top-level header: two blank lines, a rule of `=`, the centered title, another rule
/// and a trailing blank line.
fn print_h1(header: &str) {
    print_banner(header, "=", "\n\n");
}

/// Prints a second-level header: one blank line, a rule of `-`, the centered title, another
/// rule and a trailing blank line.
fn print_h2(header: &str) {
    print_banner(header, "-", "\n");
}

/// Prints `header` centered between two rules made of `rule_char`, preceded by `lead`.
fn print_banner(header: &str, rule_char: &str, lead: &str) {
    let rule = rule_char.repeat(HEADER_WIDTH);
    println!("{lead}{rule}");
    println!("{header:^HEADER_WIDTH$}");
    println!("{rule}\n");
}