eval.rs

  1mod example;
  2mod ids;
  3mod tool_metrics;
  4
  5pub(crate) use example::*;
  6pub(crate) use tool_metrics::*;
  7
  8use ::fs::RealFs;
  9use anyhow::{Result, anyhow};
 10use clap::Parser;
 11use client::{Client, ProxySettings, UserStore};
 12use collections::HashSet;
 13use extension::ExtensionHostProxy;
 14use futures::{StreamExt, future};
 15use gpui::http_client::{Uri, read_proxy_from_env};
 16use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 17use gpui_tokio::Tokio;
 18use language::LanguageRegistry;
 19use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 20use node_runtime::{NodeBinaryOptions, NodeRuntime};
 21use project::Project;
 22use project::project_settings::ProjectSettings;
 23use prompt_store::PromptBuilder;
 24use release_channel::AppVersion;
 25use reqwest_client::ReqwestClient;
 26use settings::{Settings, SettingsStore};
 27use std::env;
 28use std::path::{Path, PathBuf};
 29use std::sync::Arc;
 30use util::ResultExt as _;
 31
 32#[derive(Parser, Debug)]
 33#[command(name = "eval", disable_version_flag = true)]
 34struct Args {
 35    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 36    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 37    examples: Vec<String>,
 38    /// Model to use (default: "claude-3-7-sonnet-latest")
 39    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 40    model: String,
 41    #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
 42    languages: Vec<String>,
 43    /// How many times to run each example. Note that this is currently not very efficient as N
 44    /// worktrees will be created for the examples.
 45    #[arg(long, default_value = "1")]
 46    repetitions: u32,
 47    /// How many times to run the judge on each example run.
 48    #[arg(long, default_value = "3")]
 49    judge_repetitions: u32,
 50    /// Maximum number of examples to run concurrently.
 51    #[arg(long, default_value = "10")]
 52    concurrency: usize,
 53}
 54
 55fn main() {
 56    env_logger::init();
 57
 58    let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 59    let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 60    let session_id = uuid::Uuid::new_v4().to_string();
 61    let run_timestamp = chrono::Local::now().format("%Y-%m-%d_%H-%M-%S");
 62    let run_id = match env::var("GITHUB_RUN_ID") {
 63        Ok(run_id) => format!("github/{}", run_id),
 64        Err(_) => format!("local/{}", run_timestamp),
 65    };
 66
 67    let root_dir = Path::new(std::env!("CARGO_MANIFEST_DIR"))
 68        .parent()
 69        .unwrap()
 70        .parent()
 71        .unwrap();
 72    let eval_crate_dir = root_dir.join("crates/eval");
 73    let repos_dir = eval_crate_dir.join("repos");
 74    let worktrees_dir = eval_crate_dir.join("worktrees");
 75    let examples_dir = eval_crate_dir.join("examples");
 76    let runs_dir = eval_crate_dir.join("runs");
 77    let run_dir = runs_dir.join(format!("{}", run_timestamp));
 78    std::fs::create_dir_all(&run_dir).unwrap();
 79    std::fs::create_dir_all(&repos_dir).unwrap();
 80    std::fs::create_dir_all(&worktrees_dir).unwrap();
 81    std::fs::create_dir_all(&examples_dir).unwrap();
 82    std::fs::create_dir_all(&paths::config_dir()).unwrap();
 83
 84    let zed_commit_sha = commit_sha_for_path(root_dir);
 85    let zed_branch_name = git_branch_for_path(root_dir);
 86    let args = Args::parse();
 87    let all_available_examples = list_all_examples(&examples_dir).unwrap();
 88
 89    let example_paths = all_available_examples
 90        .iter()
 91        .filter_map(|example_path| {
 92            let name = example_path.file_name()?.to_string_lossy();
 93            if args.examples.is_empty()
 94                || args
 95                    .examples
 96                    .iter()
 97                    .any(|name_substring| name.contains(name_substring))
 98            {
 99                Some(example_path.clone())
100            } else {
101                None
102            }
103        })
104        .collect::<Vec<_>>();
105
106    let http_client = Arc::new(ReqwestClient::new());
107    let app = Application::headless().with_http_client(http_client.clone());
108
109    app.run(move |cx| {
110        let app_state = init(cx);
111
112        let telemetry = app_state.client.telemetry();
113        telemetry.start(system_id, installation_id, session_id, cx);
114
115        let enable_telemetry = env::var("ZED_EVAL_TELEMETRY").map_or(false, |value| value == "1")
116            && telemetry.has_checksum_seed();
117        if enable_telemetry {
118            println!("Telemetry enabled");
119            telemetry::event!(
120                "Agent Eval Started",
121                zed_commit_sha = zed_commit_sha,
122                zed_branch_name = zed_branch_name,
123                run_id = run_id,
124            );
125        }
126
127        let mut cumulative_tool_metrics = ToolMetrics::default();
128
129        let model_registry = LanguageModelRegistry::read_global(cx);
130        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
131        let model_provider_id = model.provider_id();
132        let model_provider = model_registry.provider(&model_provider_id).unwrap();
133
134        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
135            registry.set_default_model(
136                Some(ConfiguredModel {
137                    provider: model_provider.clone(),
138                    model: model.clone(),
139                }),
140                cx,
141            );
142        });
143
144        let authenticate_task = model_provider.authenticate(cx);
145
146        cx.spawn(async move |cx| {
147            authenticate_task.await.unwrap();
148
149            let mut examples = Vec::new();
150
151            const COLORS: [&str; 12] = [
152                "\x1b[31m", // Red
153                "\x1b[32m", // Green
154                "\x1b[33m", // Yellow
155                "\x1b[34m", // Blue
156                "\x1b[35m", // Magenta
157                "\x1b[36m", // Cyan
158                "\x1b[91m", // Bright Red
159                "\x1b[92m", // Bright Green
160                "\x1b[93m", // Bright Yellow
161                "\x1b[94m", // Bright Blue
162                "\x1b[95m", // Bright Magenta
163                "\x1b[96m", // Bright Cyan
164            ];
165
166            let mut max_name_width = 0;
167            let mut skipped = Vec::new();
168
169            for example_path in &example_paths {
170                let example = Example::load_from_directory(
171                    example_path,
172                    &run_dir,
173                    &worktrees_dir,
174                    &repos_dir,
175                )?;
176
177                if !example
178                    .base
179                    .language_extension
180                    .as_ref()
181                    .map_or(false, |lang| args.languages.contains(lang))
182                {
183                    skipped.push(example.name);
184                    continue;
185                }
186
187                // TODO: This creates a worktree per repetition. Ideally these examples should
188                // either be run sequentially on the same worktree, or reuse worktrees when there
189                // are more examples to run than the concurrency limit.
190                for repetition_number in 0..args.repetitions {
191                    let mut example = example.clone();
192                    example.set_repetition_number(repetition_number);
193
194                    let name_len = example.name.len();
195                    if name_len > max_name_width {
196                        max_name_width = example.name.len();
197                    }
198
199                    examples.push(example);
200                }
201            }
202
203            println!("Skipped examples: {}\n", skipped.join(", "));
204
205            if examples.is_empty() {
206                eprintln!("Filter matched no examples");
207                return cx.update(|cx| cx.quit());
208            }
209
210            let mut repo_urls = HashSet::default();
211            let mut clone_tasks = Vec::new();
212
213            for (i, example) in examples.iter_mut().enumerate() {
214                let color = COLORS[i % COLORS.len()].to_string();
215                example.set_log_prefix_style(&color, max_name_width);
216
217                println!(
218                    "{}Logging to: {}",
219                    example.log_prefix,
220                    example.example_output_directory().display()
221                );
222
223                let repo_url = example.base.url.clone();
224                if repo_urls.insert(repo_url.clone()) {
225                    let repo_path = example.repo_path.clone();
226
227                    if !repo_path.join(".git").is_dir() {
228                        println!(
229                            "{:<width$} < {}",
230                            "↓ Cloning",
231                            repo_url,
232                            width = max_name_width
233                        );
234
235                        let git_task = cx.spawn(async move |_cx| {
236                            std::fs::create_dir_all(&repo_path)?;
237                            run_git(&repo_path, &["init"]).await?;
238                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
239                        });
240
241                        clone_tasks.push(git_task);
242                    } else {
243                        println!(
244                            "{:<width$}  < {}",
245                            "✔︎ Already cloned",
246                            repo_url,
247                            width = max_name_width
248                        );
249
250                        let actual_origin =
251                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
252                        if actual_origin != repo_url {
253                            return Err(anyhow!(
254                                "remote origin {} does not match expected origin {}",
255                                actual_origin,
256                                repo_url,
257                            ));
258                        }
259                    }
260                }
261            }
262
263            future::join_all(clone_tasks).await;
264
265            for example in examples.iter_mut() {
266                example.setup().await?;
267            }
268
269            let judge_repetitions = args.judge_repetitions;
270            let concurrency = args.concurrency;
271
272            let tasks = examples.iter().map(|example| {
273                let app_state = app_state.clone();
274                let model = model.clone();
275                let example = example.clone();
276                let zed_commit_sha = zed_commit_sha.clone();
277                let zed_branch_name = zed_branch_name.clone();
278                let run_id = run_id.clone();
279                cx.spawn(async move |cx| {
280                    let result = async {
281                        let run_output = cx
282                            .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
283                            .await?;
284                        let judge_tasks = (0..judge_repetitions).map(|round| {
285                            run_judge_repetition(
286                                example.clone(),
287                                model.clone(),
288                                &zed_commit_sha,
289                                &zed_branch_name,
290                                &run_id,
291                                &run_output,
292                                round,
293                                enable_telemetry,
294                                cx,
295                            )
296                        });
297                        let judge_outputs = future::join_all(judge_tasks).await;
298                        anyhow::Ok((run_output, judge_outputs))
299                    }
300                    .await;
301                    (example, result)
302                })
303            });
304
305            let results = futures::stream::iter(tasks)
306                .buffer_unordered(concurrency)
307                .collect::<Vec<_>>()
308                .await;
309
310            println!("\n\n");
311            print_header("EVAL RESULTS");
312
313            let mut diff_scores = Vec::new();
314            let mut thread_scores = Vec::new();
315            let mut error_count = 0;
316
317            for (example, result) in results {
318                print_header(&example.name);
319
320                match result {
321                    Err(err) => {
322                        println!("💥 {}{:?}", example.log_prefix, err);
323                        error_count += 1;
324                    }
325                    Ok((run_output, judge_results)) => {
326                        cumulative_tool_metrics.merge(&run_output.tool_metrics);
327
328                        println!("┌───────┬──────┬────────┐");
329                        println!("│ Judge │ Diff │ Thread │");
330                        println!("├───────┼──────┼────────┤");
331
332                        for (i, judge_result) in judge_results.iter().enumerate() {
333                            match judge_result {
334                                Ok(judge_output) => {
335                                    let diff_score = judge_output.diff.score;
336                                    diff_scores.push(diff_score);
337
338                                    let thread_display = if let Some(thread) = &judge_output.thread
339                                    {
340                                        let thread_score = thread.score;
341                                        thread_scores.push(thread_score);
342                                        format!("{}", thread_score)
343                                    } else {
344                                        "N/A".to_string()
345                                    };
346
347                                    println!(
348                                        "|{:^7}{:^6}{:^8}",
349                                        i + 1,
350                                        diff_score,
351                                        thread_display
352                                    );
353                                }
354                                Err(err) => {
355                                    println!("|{:^7}{:^6}{:^8}{:?}", i + 1, "N/A", "N/A", err);
356                                }
357                            }
358                        }
359
360                        println!("└───────┴──────┴────────┘");
361
362                        println!("{}", run_output.tool_metrics);
363                    }
364                }
365                println!(
366                    "{}    > {}",
367                    " ".repeat(max_name_width),
368                    example.example_output_directory().display()
369                );
370            }
371
372            let diff_score_count = diff_scores.len();
373            let average_diff_score = diff_scores
374                .into_iter()
375                .map(|score| score as f32)
376                .sum::<f32>()
377                / (diff_score_count as f32);
378
379            if error_count > 0 {
380                println!("\n{error_count} examples failed to run!");
381            }
382
383            if diff_score_count > 0 {
384                println!("\nAverage code diff score: {average_diff_score}");
385            }
386
387            let thread_score_count = thread_scores.len();
388
389            // We might have gotten no thread scores if we weren't asked to judge the thread.
390            if thread_score_count > 0 {
391                let average_thread_score = thread_scores
392                    .into_iter()
393                    .map(|score| score as f32)
394                    .sum::<f32>()
395                    / (thread_score_count as f32);
396
397                if diff_score_count > 0 {
398                    println!("\nAverage thread score: {average_thread_score}");
399                }
400            }
401
402            print_header("CUMULATIVE TOOL METRICS");
403            println!("{}", cumulative_tool_metrics);
404
405            app_state.client.telemetry().flush_events().await;
406
407            cx.update(|cx| cx.quit())
408        })
409        .detach_and_log_err(cx);
410    });
411}
412
/// Returns every directory directly inside `examples_dir` (one per example), sorted by path.
///
/// Non-directory entries are ignored.
///
/// # Errors
///
/// Returns an error if `examples_dir` cannot be canonicalized or read, or if one of its
/// entries cannot be read.
fn list_all_examples(examples_dir: &Path) -> Result<Vec<PathBuf>> {
    let examples_dir = std::fs::canonicalize(examples_dir)?;
    let mut example_dirs = Vec::new();
    for entry in std::fs::read_dir(examples_dir)? {
        let path = entry?.path();
        if path.is_dir() {
            example_dirs.push(path);
        }
    }
    // `read_dir` yields entries in a platform/filesystem-dependent order; sort so that
    // example selection and ordering are reproducible across machines.
    example_dirs.sort();
    Ok(example_dirs)
}
426
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Constructed by [`init`] and shared behind an `Arc` with each example run.
pub struct AgentAppState {
    /// Language registry; its language-server download directory is `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Production Zed client; also the source of telemetry and of the app's HTTP client.
    pub client: Arc<Client>,
    /// User store created for `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem access (a `RealFs` when built by `init`).
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options follow the project `node` settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded with `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
438
/// Initializes the global state a headless eval run needs and returns the shared handles.
///
/// In order, this sets up:
/// - the release channel and the Tokio integration;
/// - the settings store, seeded with default settings;
/// - an HTTP client honoring the proxy settings;
/// - the production `Client`, a real filesystem, and the language registry;
/// - the user store and the Node runtime;
/// - the language, model, tool, context-server, prompt, and agent subsystems.
///
/// Finally it applies `runner_settings.json` as user settings.
///
/// # Panics
///
/// Panics if the default settings or `runner_settings.json` cannot be applied, or if the
/// HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Create the global settings store, seeded with Zed's default settings.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings. Fall back to `read_proxy_from_env` when the setting is
    // unset or does not parse as a URI.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the client with the Tokio runtime entered; the guard exits it at the end of
        // this block. NOTE(review): presumably reqwest needs a runtime context here — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // The client is created while the proxied HTTP client above is installed. The app's
    // HTTP client is then replaced with the client's own.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Each time the `SettingsStore` global is observed, recompute the Node binary options
    // from the project `node` settings. Publish them to the Node runtime through a watch
    // channel that starts out as `None`.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use an `npm` entry beside the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    // The runner is headless, so stdout is always reported to `PromptBuilder::load` as not a PTY.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the eval runner's own settings file as the user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
543
544pub fn find_model(
545    model_name: &str,
546    model_registry: &LanguageModelRegistry,
547    cx: &App,
548) -> anyhow::Result<Arc<dyn LanguageModel>> {
549    let model = model_registry
550        .available_models(cx)
551        .find(|model| model.id().0 == model_name);
552
553    let Some(model) = model else {
554        return Err(anyhow!(
555            "No language model named {} was available. Available models: {}",
556            model_name,
557            model_registry
558                .available_models(cx)
559                .map(|model| model.id().0.clone())
560                .collect::<Vec<_>>()
561                .join(", ")
562        ));
563    };
564
565    Ok(model)
566}
567
568pub fn commit_sha_for_path(repo_path: &Path) -> String {
569    futures::executor::block_on(run_git(repo_path, &["rev-parse", "HEAD"])).unwrap()
570}
571
572pub fn git_branch_for_path(repo_path: &Path) -> String {
573    match std::env::var("GITHUB_REF_NAME") {
574        Ok(branch) => branch,
575        Err(_) => {
576            futures::executor::block_on(run_git(repo_path, &["rev-parse", "--abbrev-ref", "HEAD"]))
577                .unwrap_or_else(|_| "unknown".to_string())
578        }
579    }
580}
581
582async fn run_judge_repetition(
583    example: Example,
584    model: Arc<dyn LanguageModel>,
585    zed_commit_sha: &str,
586    zed_branch_name: &str,
587    run_id: &str,
588    run_output: &RunOutput,
589    round: u32,
590    enable_telemetry: bool,
591    cx: &AsyncApp,
592) -> Result<JudgeOutput> {
593    let judge_output = example.judge(model.clone(), &run_output, round, cx).await;
594
595    let diff_evaluation;
596    let thread_diff_evaluation;
597    if let Ok(output) = judge_output.as_ref() {
598        diff_evaluation = Some(output.diff.clone());
599        thread_diff_evaluation = output.thread.clone();
600    } else {
601        diff_evaluation = None;
602        thread_diff_evaluation = None;
603    }
604
605    if enable_telemetry {
606        telemetry::event!(
607            "Agent Example Evaluated",
608            zed_commit_sha = zed_commit_sha,
609            zed_branch_name = zed_branch_name,
610            run_id = run_id,
611            example_name = example.name.clone(),
612            round = round,
613            diff_evaluation = diff_evaluation,
614            thread_evaluation = thread_diff_evaluation,
615            tool_metrics = run_output.tool_metrics,
616            response_count = run_output.response_count,
617            token_usage = run_output.token_usage,
618            model = model.telemetry_id(),
619            model_provider = model.provider_id().to_string(),
620            repository_url = example.base.url.clone(),
621            repository_revision = example.base.revision.clone(),
622            diagnostics_before = run_output.diagnostics_before,
623            diagnostics_after = run_output.diagnostics_after,
624        );
625    }
626
627    judge_output
628}
629
/// Prints `header` centered in a 40-column field, framed by rules of `=` with a blank line
/// before and after.
fn print_header(header: &str) {
    const RULE: &str = "========================================";
    println!("\n{RULE}");
    println!("{header:^40}");
    println!("{RULE}\n");
}