eval.rs

  1mod example;
  2mod ids;
  3mod tool_metrics;
  4
  5pub(crate) use example::*;
  6use parking_lot::Mutex;
  7pub(crate) use tool_metrics::*;
  8
  9use ::fs::RealFs;
 10use anyhow::{Result, anyhow};
 11use clap::Parser;
 12use client::{Client, ProxySettings, UserStore};
 13use collections::{HashMap, HashSet};
 14use extension::ExtensionHostProxy;
 15use futures::future;
 16use gpui::http_client::{Uri, read_proxy_from_env};
 17use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 18use gpui_tokio::Tokio;
 19use language::LanguageRegistry;
 20use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 21use node_runtime::{NodeBinaryOptions, NodeRuntime};
 22use project::Project;
 23use project::project_settings::ProjectSettings;
 24use prompt_store::PromptBuilder;
 25use release_channel::AppVersion;
 26use reqwest_client::ReqwestClient;
 27use settings::{Settings, SettingsStore};
 28use std::collections::VecDeque;
 29use std::env;
 30use std::path::{Path, PathBuf};
 31use std::sync::Arc;
 32use util::ResultExt as _;
 33
 34#[derive(Parser, Debug)]
 35#[command(name = "eval", disable_version_flag = true)]
 36struct Args {
 37    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 38    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 39    examples: Vec<String>,
 40    /// Model to use (default: "claude-3-7-sonnet-latest")
 41    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 42    model: String,
 43    #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
 44    languages: Vec<String>,
 45    /// How many times to run each example.
 46    #[arg(long, default_value = "1")]
 47    repetitions: usize,
 48    /// Maximum number of examples to run concurrently.
 49    #[arg(long, default_value = "10")]
 50    concurrency: usize,
 51}
 52
 53fn main() {
 54    env_logger::init();
 55
 56    let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 57    let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 58    let session_id = uuid::Uuid::new_v4().to_string();
 59    let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
 60    let run_id = match env::var("GITHUB_RUN_ID") {
 61        Ok(run_id) => format!("github/{}", run_id),
 62        Err(_) => format!("local/{}", run_timestamp),
 63    };
 64
 65    let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
 66        .parent()
 67        .unwrap()
 68        .parent()
 69        .unwrap();
 70    let eval_crate_dir = root_dir.join("crates/eval");
 71    let repos_dir = eval_crate_dir.join("repos");
 72    let worktrees_dir = eval_crate_dir.join("worktrees");
 73    let examples_dir = eval_crate_dir.join("examples");
 74    let runs_dir = eval_crate_dir.join("runs");
 75    let run_dir = runs_dir.join(format!("{}", run_timestamp));
 76    std::fs::create_dir_all(&run_dir).unwrap();
 77    std::fs::create_dir_all(&repos_dir).unwrap();
 78    std::fs::create_dir_all(&worktrees_dir).unwrap();
 79    std::fs::create_dir_all(&examples_dir).unwrap();
 80    std::fs::create_dir_all(&paths::config_dir()).unwrap();
 81
 82    let zed_commit_sha = commit_sha_for_path(root_dir);
 83    let zed_branch_name = git_branch_for_path(root_dir);
 84    let args = Args::parse();
 85    let all_available_examples = list_all_examples(&examples_dir).unwrap();
 86
 87    let example_paths = all_available_examples
 88        .iter()
 89        .filter_map(|example_path| {
 90            let name = example_path.file_name()?.to_string_lossy();
 91            if args.examples.is_empty()
 92                || args
 93                    .examples
 94                    .iter()
 95                    .any(|name_substring| name.contains(name_substring))
 96            {
 97                Some(example_path.clone())
 98            } else {
 99                None
100            }
101        })
102        .collect::<Vec<_>>();
103
104    let http_client = Arc::new(ReqwestClient::new());
105    let app = Application::headless().with_http_client(http_client.clone());
106
107    app.run(move |cx| {
108        let app_state = init(cx);
109
110        let telemetry = app_state.client.telemetry();
111        telemetry.start(system_id, installation_id, session_id, cx);
112
113        let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
114            && telemetry.has_checksum_seed();
115        if enable_telemetry {
116            println!("Telemetry enabled");
117            telemetry::event!(
118                "Agent Eval Started",
119                zed_commit_sha = zed_commit_sha,
120                zed_branch_name = zed_branch_name,
121                run_id = run_id,
122            );
123        }
124
125        let mut cumulative_tool_metrics = ToolMetrics::default();
126
127        let model_registry = LanguageModelRegistry::read_global(cx);
128        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
129        let model_provider_id = model.provider_id();
130        let model_provider = model_registry.provider(&model_provider_id).unwrap();
131
132        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
133            registry.set_default_model(
134                Some(ConfiguredModel {
135                    provider: model_provider.clone(),
136                    model: model.clone(),
137                }),
138                cx,
139            );
140        });
141
142        let authenticate_task = model_provider.authenticate(cx);
143
144        cx.spawn(async move |cx| {
145            authenticate_task.await.unwrap();
146
147            let mut examples = Vec::new();
148
149            const COLORS: [&str; 12] = [
150                "\x1b[31m", // Red
151                "\x1b[32m", // Green
152                "\x1b[33m", // Yellow
153                "\x1b[34m", // Blue
154                "\x1b[35m", // Magenta
155                "\x1b[36m", // Cyan
156                "\x1b[91m", // Bright Red
157                "\x1b[92m", // Bright Green
158                "\x1b[93m", // Bright Yellow
159                "\x1b[94m", // Bright Blue
160                "\x1b[95m", // Bright Magenta
161                "\x1b[96m", // Bright Cyan
162            ];
163
164            let mut skipped = Vec::new();
165
166            for example_path in &example_paths {
167                let example = Example::load_from_directory(
168                    example_path,
169                    &run_dir,
170                    &worktrees_dir,
171                    &repos_dir,
172                )?;
173
174                if !example
175                    .base
176                    .language_extension
177                    .as_ref()
178                    .map_or(false, |lang| args.languages.contains(lang))
179                {
180                    skipped.push(example.name);
181                    continue;
182                }
183
184                examples.extend(example.repeat(args.repetitions));
185            }
186
187            println!("Skipped examples: {}\n", skipped.join(", "));
188
189            if examples.is_empty() {
190                eprintln!("Filter matched no examples");
191                return cx.update(|cx| cx.quit());
192            }
193
194            let mut repo_urls = HashSet::default();
195            let mut clone_tasks = Vec::new();
196
197            let max_name_width = examples
198                .iter()
199                .map(|e| e.repetition_name().len())
200                .max()
201                .unwrap_or(0);
202            for (i, example) in examples.iter_mut().enumerate() {
203                let color = COLORS[i % COLORS.len()].to_string();
204                example.set_log_prefix_style(&color, max_name_width);
205
206                println!(
207                    "{}Logging to: {}",
208                    example.log_prefix,
209                    example.run_directory_path().display()
210                );
211
212                let repo_url = example.base.url.clone();
213                if repo_urls.insert(repo_url.clone()) {
214                    let repo_path = example.repo_path.clone();
215
216                    if !repo_path.join(".git").is_dir() {
217                        println!(
218                            "{:<width$} < {}",
219                            "↓ Cloning",
220                            repo_url,
221                            width = max_name_width
222                        );
223
224                        let git_task = cx.spawn(async move |_cx| {
225                            std::fs::create_dir_all(&repo_path)?;
226                            run_git(&repo_path, &["init"]).await?;
227                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
228                        });
229
230                        clone_tasks.push(git_task);
231                    } else {
232                        println!(
233                            "{:<width$}  < {}",
234                            "✔︎ Already cloned",
235                            repo_url,
236                            width = max_name_width
237                        );
238
239                        let actual_origin =
240                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
241                        if actual_origin != repo_url {
242                            return Err(anyhow!(
243                                "remote origin {} does not match expected origin {}",
244                                actual_origin,
245                                repo_url,
246                            ));
247                        }
248                    }
249                }
250            }
251
252            future::join_all(clone_tasks).await;
253
254            for example in examples.iter_mut() {
255                example.fetch().await?;
256            }
257
258            let examples = Arc::new(Mutex::new(VecDeque::from(examples)));
259            let results_by_example_name = Arc::new(Mutex::new(HashMap::default()));
260
261            future::join_all((0..args.concurrency).map(|_| {
262                let app_state = app_state.clone();
263                let model = model.clone();
264                let zed_commit_sha = zed_commit_sha.clone();
265                let zed_branch_name = zed_branch_name.clone();
266                let run_id = run_id.clone();
267                let examples = examples.clone();
268                let results = results_by_example_name.clone();
269                cx.spawn(async move |cx| {
270                    loop {
271                        let Some(mut example) = examples.lock().pop_front() else {
272                            break;
273                        };
274                        let result = async {
275                            example.setup().await?;
276                            let run_output = cx
277                                .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
278                                .await?;
279                            let judge_output = judge_example(
280                                example.clone(),
281                                model.clone(),
282                                &zed_commit_sha,
283                                &zed_branch_name,
284                                &run_id,
285                                &run_output,
286                                enable_telemetry,
287                                cx,
288                            )
289                            .await;
290                            anyhow::Ok((run_output, judge_output))
291                        }
292                        .await;
293                        results
294                            .lock()
295                            .entry(example.name.clone())
296                            .or_insert(Vec::new())
297                            .push((example.clone(), result));
298                    }
299                })
300            }))
301            .await;
302
303            println!("\n\n");
304            print_header("EVAL RESULTS");
305
306            let mut diff_scores = Vec::new();
307            let mut thread_scores = Vec::new();
308            let mut error_count = 0;
309
310            for (example_name, results) in results_by_example_name.lock().iter_mut() {
311                print_header(&example_name);
312
313                results.sort_unstable_by_key(|(example, _)| example.repetition);
314                let mut example_cumulative_tool_metrics = ToolMetrics::default();
315
316                println!("┌───────┬──────┬────────┐");
317                println!("│ Round │ Diff │ Thread │");
318                println!("├───────┼──────┼────────┤");
319                for (example, result) in results {
320                    let run_dir_path = example.run_directory_path();
321                    let relative_run_dir_path = run_dir_path.strip_prefix(root_dir).unwrap();
322
323                    match result {
324                        Err(err) => {
325                            println!(
326                                "|{:^7}{:^6}{:^8}{:?}{}",
327                                example.repetition,
328                                "N/A",
329                                "N/A",
330                                err,
331                                relative_run_dir_path.display()
332                            );
333                            error_count += 1;
334                        }
335                        Ok((run_output, judge_result)) => {
336                            cumulative_tool_metrics.merge(&run_output.tool_metrics);
337                            example_cumulative_tool_metrics.merge(&run_output.tool_metrics);
338
339                            match judge_result {
340                                Ok(judge_output) => {
341                                    diff_scores.push(judge_output.diff.score());
342                                    thread_scores.push(judge_output.thread.score());
343                                    println!(
344                                        "|{:^7}{:^6}{:^8}{}",
345                                        example.repetition,
346                                        format!("{}%", judge_output.diff.score()),
347                                        format!("{}%", judge_output.thread.score()),
348                                        relative_run_dir_path.display()
349                                    );
350                                }
351                                Err(err) => {
352                                    println!(
353                                        "|{:^7}{:^6}{:^8}{:?}{}",
354                                        example.repetition,
355                                        "N/A",
356                                        "N/A",
357                                        err,
358                                        relative_run_dir_path.display()
359                                    );
360                                }
361                            }
362                        }
363                    }
364                }
365
366                println!("└───────┴──────┴────────┘");
367                println!("{}", example_cumulative_tool_metrics);
368            }
369
370            let diff_score_count = diff_scores.len();
371            let average_diff_score = diff_scores
372                .into_iter()
373                .map(|score| score as f32)
374                .sum::<f32>()
375                / (diff_score_count as f32);
376
377            if error_count > 0 {
378                println!("\n{error_count} examples failed to run!");
379            }
380
381            println!("\nAverage code diff score: {average_diff_score}");
382
383            let thread_score_count = thread_scores.len();
384            let average_thread_score = thread_scores
385                .into_iter()
386                .map(|score| score as f32)
387                .sum::<f32>()
388                / (thread_score_count as f32);
389
390            println!("\nAverage thread score: {average_thread_score}");
391
392            print_header("CUMULATIVE TOOL METRICS");
393            println!("{}", cumulative_tool_metrics);
394
395            app_state.client.telemetry().flush_events().await;
396
397            cx.update(|cx| cx.quit())
398        })
399        .detach_and_log_err(cx);
400    });
401}
402
/// Returns the path of every directory directly inside `examples_dir`, sorted.
///
/// Each subdirectory is one example; plain files are ignored. The returned
/// paths are built from the canonicalized `examples_dir`. They are sorted
/// because `read_dir` yields entries in a platform-dependent order, and the
/// caller should enumerate examples in a stable order.
///
/// # Errors
///
/// Returns an error if `examples_dir` cannot be canonicalized or read, or if
/// reading one of its entries fails.
fn list_all_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
    let examples_dir = std::fs::canonicalize(examples_dir)?;
    let mut result_paths = Vec::new();
    for entry in std::fs::read_dir(&examples_dir)? {
        let path = entry?.path();
        if path.is_dir() {
            result_paths.push(path);
        }
    }
    result_paths.sort();
    Ok(result_paths)
}
416
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by `init` and shared behind an `Arc`.
// NOTE(review): `HeadlessAssistant` is not referenced anywhere in this file; here
// the state is handed to `Example::run`. Confirm the doc still names the right consumer.
pub struct AgentAppState {
    /// Language registry; `init` sets its language-server download directory.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `Client::production` in `init`.
    pub client: Arc<Client>,
    /// User store created from `client` in `init`.
    pub user_store: Entity<UserStore>,
    /// Filesystem handle (a `RealFs` when built by `init`).
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options are derived from the `node` project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded in `init` via `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
428
/// Builds the shared [`AgentAppState`] and initializes every subsystem the
/// agent needs inside the headless app.
///
/// The statements run in a fixed order, and later steps read globals installed
/// by earlier ones. For example, `ProxySettings` is read only after the
/// `SettingsStore` global is installed and `client::init_settings` has run.
/// Reorder with care.
///
/// # Panics
///
/// Panics if the default settings or `../runner_settings.json` cannot be
/// applied, or if the proxied HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // Release-channel globals are initialized with a default `SemanticVersion`
    // rather than a real build version.
    release_channel::init(SemanticVersion::default(), cx);
    // NOTE(review): presumably installs the Tokio runtime that `Tokio::handle(cx)`
    // relies on below. Confirm.
    gpui_tokio::init(cx);

    // Install a settings store seeded with the built-in default settings.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings when it parses as a URI; otherwise fall
    // back to the proxy configured in the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the client while the Tokio runtime is entered. The guard is
        // dropped at the end of this block.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // NOTE(review): this replaces the HTTP client installed just above with the
    // one held by `Client`. Presumably `Client::production` wraps the current app
    // client; confirm that the proxy/user-agent configuration carries over.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    // `None` leaves the git binary path for `RealFs` unset.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    // Language servers are downloaded into `paths::languages_dir()`.
    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Node binary options are published through a watch channel whose receiver
    // is handed to `NodeRuntime`. The channel starts as `None`. The observer
    // recomputes the options from the `node` project settings each time it is
    // invoked for the `SettingsStore` global.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                // `~` is expanded in both configured paths.
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    // NOTE(review): hard-coded to `false` for this headless runner. Confirm what
    // `PromptBuilder::load` changes based on this flag.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the eval runner's user settings, embedded at compile time, after
    // every subsystem above has been initialized.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
533
534pub fn find_model(
535    model_name: &str,
536    model_registry: &LanguageModelRegistry,
537    cx: &App,
538) -> anyhow::Result<Arc<dyn LanguageModel>> {
539    let model = model_registry
540        .available_models(cx)
541        .find(|model| model.id().0 == model_name);
542
543    let Some(model) = model else {
544        return Err(anyhow!(
545            "No language model named {} was available. Available models: {}",
546            model_name,
547            model_registry
548                .available_models(cx)
549                .map(|model| model.id().0.clone())
550                .collect::<Vec<_>>()
551                .join(", ")
552        ));
553    };
554
555    Ok(model)
556}
557
558pub fn commit_sha_for_path(repo_path: &Path) -> String {
559    futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
560}
561
562pub fn git_branch_for_path(repo_path: &Path) -> String {
563    match std::env::var("GITHUB_REF_NAME") {
564        Ok(branch) => branch,
565        Err(_) => {
566            futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
567                .unwrap_or_else(|_| "unknown".to_string())
568        }
569    }
570}
571
572async fn judge_example(
573    example: Example,
574    model: Arc<dyn LanguageModel>,
575    zed_commit_sha: &str,
576    zed_branch_name: &str,
577    run_id: &str,
578    run_output: &RunOutput,
579    enable_telemetry: bool,
580    cx: &AsyncApp,
581) -> Result<JudgeOutput> {
582    let judge_output = example.judge(model.clone(), &run_output, cx).await;
583
584    let diff_evaluation;
585    let thread_evaluation;
586    if let Ok(output) = judge_output.as_ref() {
587        diff_evaluation = Some(output.diff.clone());
588        thread_evaluation = Some(output.thread.clone());
589    } else {
590        diff_evaluation = None;
591        thread_evaluation = None;
592    }
593
594    if enable_telemetry {
595        telemetry::event!(
596            "Agent Example Evaluated",
597            zed_commit_sha = zed_commit_sha,
598            zed_branch_name = zed_branch_name,
599            run_id = run_id,
600            example_name = example.name.clone(),
601            example_repetition = example.repetition,
602            diff_evaluation = diff_evaluation,
603            thread_evaluation = thread_evaluation,
604            tool_metrics = run_output.tool_metrics,
605            response_count = run_output.response_count,
606            token_usage = run_output.token_usage,
607            model = model.telemetry_id(),
608            model_provider = model.provider_id().to_string(),
609            repository_url = example.base.url.clone(),
610            repository_revision = example.base.revision.clone(),
611            diagnostic_summary_before = run_output.diagnostic_summary_before,
612            diagnostic_summary_after = run_output.diagnostic_summary_after,
613            diagnostics_before = run_output.diagnostics_before,
614            diagnostics_after = run_output.diagnostics_after,
615        );
616    }
617
618    judge_output
619}
620
/// Prints `header` centered in a 40-column banner between two rows of `=`,
/// with an empty line before and after the banner.
fn print_header(header: &str) {
    const RULE: &str = "========================================";
    println!("\n{}", RULE);
    println!("{:^40}", header);
    println!("{}\n", RULE);
}