eval.rs

  1mod example;
  2mod ids;
  3mod tool_metrics;
  4
  5pub(crate) use example::*;
  6pub(crate) use tool_metrics::*;
  7
  8use ::fs::RealFs;
  9use anyhow::{Result, anyhow};
 10use clap::Parser;
 11use client::{Client, ProxySettings, UserStore};
 12use collections::HashSet;
 13use extension::ExtensionHostProxy;
 14use futures::{StreamExt, future};
 15use gpui::http_client::{Uri, read_proxy_from_env};
 16use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 17use gpui_tokio::Tokio;
 18use language::LanguageRegistry;
 19use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 20use node_runtime::{NodeBinaryOptions, NodeRuntime};
 21use project::Project;
 22use project::project_settings::ProjectSettings;
 23use prompt_store::PromptBuilder;
 24use release_channel::AppVersion;
 25use reqwest_client::ReqwestClient;
 26use settings::{Settings, SettingsStore};
 27use std::path::{Path, PathBuf};
 28use std::sync::Arc;
 29use std::usize;
 30use util::ResultExt as _;
 31
/// Root directory, relative to the current working directory, under which `main`
/// creates one timestamped subdirectory (`%Y-%m-%d_%H-%M-%S`) per eval invocation.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 33
 34#[derive(Parser, Debug)]
 35#[command(name = "eval", disable_version_flag = true)]
 36struct Args {
 37    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 38    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 39    examples: Vec<String>,
 40    /// Model to use (default: "claude-3-7-sonnet-latest")
 41    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 42    model: String,
 43    #[arg(long, value_delimiter = ',', default_value = "rs,ts")]
 44    languages: Vec<String>,
 45    /// How many times to run each example. Note that this is currently not very efficient as N
 46    /// worktrees will be created for the examples.
 47    #[arg(long, default_value = "1")]
 48    repetitions: u32,
 49    /// How many times to run the judge on each example run.
 50    #[arg(long, default_value = "3")]
 51    judge_repetitions: u32,
 52    /// Maximum number of examples to run concurrently.
 53    #[arg(long, default_value = "10")]
 54    concurrency: usize,
 55}
 56
 57fn main() {
 58    env_logger::init();
 59
 60    let args = Args::parse();
 61    let all_available_examples = list_all_examples().unwrap();
 62
 63    let example_paths = all_available_examples
 64        .iter()
 65        .filter_map(|example_path| {
 66            let name = example_path.file_name()?.to_string_lossy();
 67            if args.examples.is_empty()
 68                || args
 69                    .examples
 70                    .iter()
 71                    .any(|name_substring| name.contains(name_substring))
 72            {
 73                Some(example_path.clone())
 74            } else {
 75                None
 76            }
 77        })
 78        .collect::<Vec<_>>();
 79
 80    let http_client = Arc::new(ReqwestClient::new());
 81    let app = Application::headless().with_http_client(http_client.clone());
 82
 83    app.run(move |cx| {
 84        let app_state = init(cx);
 85
 86        let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 87        let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 88        let session_id = uuid::Uuid::new_v4().to_string();
 89
 90        app_state
 91            .client
 92            .telemetry()
 93            .start(system_id, installation_id, session_id, cx);
 94
 95        let mut cumulative_tool_metrics = ToolMetrics::default();
 96
 97        let model_registry = LanguageModelRegistry::read_global(cx);
 98        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
 99        let model_provider_id = model.provider_id();
100        let model_provider = model_registry.provider(&model_provider_id).unwrap();
101
102        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
103            registry.set_default_model(
104                Some(ConfiguredModel {
105                    provider: model_provider.clone(),
106                    model: model.clone(),
107                }),
108                cx,
109            );
110        });
111
112        let authenticate_task = model_provider.authenticate(cx);
113
114        cx.spawn(async move |cx| {
115            authenticate_task.await.unwrap();
116
117            std::fs::create_dir_all(REPOS_DIR)?;
118            std::fs::create_dir_all(WORKTREES_DIR)?;
119
120            let run_dir = Path::new(RUNS_DIR).join(format!(
121                "{}",
122                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
123            ));
124            std::fs::create_dir_all(&run_dir)?;
125
126            let mut examples = Vec::new();
127
128            const COLORS: [&str; 12] = [
129                "\x1b[31m", // Red
130                "\x1b[32m", // Green
131                "\x1b[33m", // Yellow
132                "\x1b[34m", // Blue
133                "\x1b[35m", // Magenta
134                "\x1b[36m", // Cyan
135                "\x1b[91m", // Bright Red
136                "\x1b[92m", // Bright Green
137                "\x1b[93m", // Bright Yellow
138                "\x1b[94m", // Bright Blue
139                "\x1b[95m", // Bright Magenta
140                "\x1b[96m", // Bright Cyan
141            ];
142
143            let mut max_name_width = 0;
144            let mut skipped = Vec::new();
145
146            for example_path in &example_paths {
147                let example = Example::load_from_directory(example_path, &run_dir)?;
148
149                if !example
150                    .base
151                    .language_extension
152                    .as_ref()
153                    .map_or(false, |lang| args.languages.contains(lang))
154                {
155                    skipped.push(example.name);
156                    continue;
157                }
158
159                // TODO: This creates a worktree per repetition. Ideally these examples should
160                // either be run sequentially on the same worktree, or reuse worktrees when there
161                // are more examples to run than the concurrency limit.
162                for repetition_number in 0..args.repetitions {
163                    let mut example = example.clone();
164                    example.set_repetition_number(repetition_number);
165
166                    let name_len = example.name.len();
167                    if name_len > max_name_width {
168                        max_name_width = example.name.len();
169                    }
170
171                    examples.push(example);
172                }
173            }
174
175            println!("Skipped examples: {}\n", skipped.join(", "));
176
177            if examples.is_empty() {
178                eprintln!("Filter matched no examples");
179                return cx.update(|cx| cx.quit());
180            }
181
182            let mut repo_urls = HashSet::default();
183            let mut clone_tasks = Vec::new();
184
185            for (i, example) in examples.iter_mut().enumerate() {
186                let color = COLORS[i % COLORS.len()].to_string();
187                example.set_log_prefix_style(&color, max_name_width);
188
189                println!(
190                    "{}Logging to: {}",
191                    example.log_prefix,
192                    example.example_output_directory().display()
193                );
194
195                let repo_url = example.base.url.clone();
196                if repo_urls.insert(repo_url.clone()) {
197                    let repo_path = repo_path_for_url(&repo_url);
198
199                    if !repo_path.join(".git").is_dir() {
200                        println!(
201                            "{:<width$} < {}",
202                            "↓ Cloning",
203                            repo_url,
204                            width = max_name_width
205                        );
206
207                        let git_task = cx.spawn(async move |_cx| {
208                            std::fs::create_dir_all(&repo_path)?;
209                            run_git(&repo_path, &["init"]).await?;
210                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
211                        });
212
213                        clone_tasks.push(git_task);
214                    } else {
215                        println!(
216                            "{:<width$}  < {}",
217                            "✔︎ Already cloned",
218                            repo_url,
219                            width = max_name_width
220                        );
221
222                        let actual_origin =
223                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
224                        if actual_origin != repo_url {
225                            return Err(anyhow!(
226                                "remote origin {} does not match expected origin {}",
227                                actual_origin,
228                                repo_url,
229                            ));
230                        }
231                    }
232                }
233            }
234
235            future::join_all(clone_tasks).await;
236
237            for example in examples.iter_mut() {
238                example.setup().await?;
239            }
240
241            let judge_repetitions = args.judge_repetitions;
242            let concurrency = args.concurrency;
243
244            let tasks = examples.iter().map(|example| {
245                let app_state = app_state.clone();
246                let model = model.clone();
247                let example = example.clone();
248                cx.spawn(async move |cx| {
249                    let result = async {
250                        let run_output = cx
251                            .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
252                            .await?;
253                        let judge_tasks = (0..judge_repetitions).map(|round| {
254                            run_judge_repetition(
255                                example.clone(),
256                                model.clone(),
257                                &run_output,
258                                round,
259                                cx,
260                            )
261                        });
262                        let judge_outputs = future::join_all(judge_tasks).await;
263                        anyhow::Ok((run_output, judge_outputs))
264                    }
265                    .await;
266                    (example, result)
267                })
268            });
269
270            let results = futures::stream::iter(tasks)
271                .buffer_unordered(concurrency)
272                .collect::<Vec<_>>()
273                .await;
274
275            println!("\n\n");
276            print_header("EVAL RESULTS");
277
278            let mut diff_scores = Vec::new();
279            let mut thread_scores = Vec::new();
280            let mut error_count = 0;
281
282            for (example, result) in results {
283                print_header(&example.name);
284
285                match result {
286                    Err(err) => {
287                        println!("💥 {}{:?}", example.log_prefix, err);
288                        error_count += 1;
289                    }
290                    Ok((run_output, judge_results)) => {
291                        cumulative_tool_metrics.merge(&run_output.tool_metrics);
292
293                        println!("┌───────┬──────┬────────┐");
294                        println!("│ Judge │ Diff │ Thread │");
295                        println!("├───────┼──────┼────────┤");
296
297                        for (i, judge_result) in judge_results.iter().enumerate() {
298                            match judge_result {
299                                Ok(judge_output) => {
300                                    let diff_score = judge_output.diff.score;
301                                    diff_scores.push(diff_score);
302
303                                    let thread_display = if let Some(thread) = &judge_output.thread
304                                    {
305                                        let thread_score = thread.score;
306                                        thread_scores.push(thread_score);
307                                        format!("{}", thread_score)
308                                    } else {
309                                        "N/A".to_string()
310                                    };
311
312                                    println!(
313                                        "|{:^7}{:^6}{:^8}",
314                                        i + 1,
315                                        diff_score,
316                                        thread_display
317                                    );
318                                }
319                                Err(err) => {
320                                    println!("|{:^7}{:^6}{:^8}{:?}", i + 1, "N/A", "N/A", err);
321                                }
322                            }
323                        }
324
325                        println!("└───────┴──────┴────────┘");
326
327                        println!("{}", run_output.tool_metrics);
328                    }
329                }
330                println!(
331                    "{}    > {}",
332                    " ".repeat(max_name_width),
333                    example.example_output_directory().display()
334                );
335            }
336
337            let diff_score_count = diff_scores.len();
338            let average_diff_score = diff_scores
339                .into_iter()
340                .map(|score| score as f32)
341                .sum::<f32>()
342                / (diff_score_count as f32);
343
344            if error_count > 0 {
345                println!("\n{error_count} examples failed to run!");
346            }
347
348            if diff_score_count > 0 {
349                println!("\nAverage code diff score: {average_diff_score}");
350            }
351
352            let thread_score_count = thread_scores.len();
353
354            // We might have gotten no thread scores if we weren't asked to judge the thread.
355            if thread_score_count > 0 {
356                let average_thread_score = thread_scores
357                    .into_iter()
358                    .map(|score| score as f32)
359                    .sum::<f32>()
360                    / (thread_score_count as f32);
361
362                if diff_score_count > 0 {
363                    println!("\nAverage thread score: {average_thread_score}");
364                }
365            }
366
367            print_header("CUMULATIVE TOOL METRICS");
368            println!("{}", cumulative_tool_metrics);
369
370            std::thread::sleep(std::time::Duration::from_secs(2));
371
372            app_state.client.telemetry().flush_events();
373
374            cx.update(|cx| cx.quit())
375        })
376        .detach_and_log_err(cx);
377    });
378}
379
380fn list_all_examples() -> Result<Vec<PathBuf>> {
381    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
382    let entries = std::fs::read_dir(path).unwrap();
383    let mut result_paths = Vec::new();
384    for entry in entries {
385        let entry = entry?;
386        let path = entry.path();
387        if path.is_dir() {
388            result_paths.push(path);
389        }
390    }
391    Ok(result_paths)
392}
393
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by `init` and cloned (behind an `Arc`) into every example run in `main`.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Production client; its HTTP client is shared app-wide and it provides telemetry.
    pub client: Arc<Client>,
    /// User store created from `client` in `init`.
    pub user_store: Entity<UserStore>,
    /// Filesystem implementation (`init` supplies a `RealFs`).
    pub fs: Arc<dyn fs::Fs>,
    /// Node.js runtime whose binary options follow the project's `node` settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init` via `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
405
/// Initializes the global state the headless eval runner needs and returns the shared
/// handles bundled into an [`AgentAppState`].
///
/// In order, this sets up:
/// - the release channel and Tokio integration;
/// - the settings store, seeded with the default settings;
/// - an HTTP client that honors the configured or environment proxy;
/// - the production `Client`, a real filesystem, the language registry and the user store;
/// - the Node.js runtime;
/// - the language, model, tool, context-server, prompt and agent subsystems.
///
/// Finally it applies `runner_settings.json` as the user settings.
///
/// # Panics
///
/// Panics if the default settings or `runner_settings.json` fail to load, or if the HTTP
/// client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings; fall back to the proxy environment variables.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the client inside the Tokio runtime context (presumably reqwest requires
        // one at construction time).
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // Install the client's HTTP client as the app-wide one, replacing the one set above.
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Recompute the Node.js binary options from the project `node` settings whenever the
    // `SettingsStore` global is observed to change, and publish them to the runtime
    // through a watch channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    // The eval runner is headless; stdout is treated as not being a terminal.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the runner-specific settings last, on top of everything initialized above.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
510
511pub fn find_model(
512    model_name: &str,
513    model_registry: &LanguageModelRegistry,
514    cx: &App,
515) -> anyhow::Result<Arc<dyn LanguageModel>> {
516    let model = model_registry
517        .available_models(cx)
518        .find(|model| model.id().0 == model_name);
519
520    let Some(model) = model else {
521        return Err(anyhow!(
522            "No language model named {} was available. Available models: {}",
523            model_name,
524            model_registry
525                .available_models(cx)
526                .map(|model| model.id().0.clone())
527                .collect::<Vec<_>>()
528                .join(", ")
529        ));
530    };
531
532    Ok(model)
533}
534
535pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
536    (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
537}
538
539pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
540    futures::executor::block_on(async {
541        get_current_commit_id(repo_path).await.unwrap_or_default()
542    })
543}
544
545async fn run_judge_repetition(
546    example: Example,
547    model: Arc<dyn LanguageModel>,
548    run_output: &RunOutput,
549    round: u32,
550    cx: &AsyncApp,
551) -> Result<JudgeOutput> {
552    let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
553
554    if let Ok(judge_output) = &judge_result {
555        let cohort_id = example
556            .run_directory_path
557            .file_name()
558            .map(|name| name.to_string_lossy().to_string())
559            .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
560
561        let path = std::path::Path::new(".");
562        let commit_id = get_current_commit_id(path).await.unwrap_or_default();
563
564        if let Some(thread) = &judge_output.thread {
565            telemetry::event!(
566                "Agent Eval Completed",
567                cohort_id = cohort_id,
568                example_name = example.name.clone(),
569                round = round,
570                diff_score = judge_output.diff.score,
571                diff_analysis = judge_output.diff.analysis,
572                thread_score = thread.score,
573                thread_analysis = thread.analysis,
574                tool_metrics = run_output.tool_metrics,
575                response_count = run_output.response_count,
576                token_usage = run_output.token_usage,
577                model = model.telemetry_id(),
578                model_provider = model.provider_id().to_string(),
579                repository_url = example.base.url.clone(),
580                repository_revision = example.base.revision.clone(),
581                diagnostics_before = run_output.diagnostics_before,
582                diagnostics_after = run_output.diagnostics_after,
583                commit_id = commit_id
584            );
585        } else {
586            telemetry::event!(
587                "Agent Eval Completed",
588                cohort_id = cohort_id,
589                example_name = example.name.clone(),
590                round = round,
591                diff_score = judge_output.diff.score,
592                diff_analysis = judge_output.diff.analysis,
593                tool_metrics = run_output.tool_metrics,
594                response_count = run_output.response_count,
595                token_usage = run_output.token_usage,
596                model = model.telemetry_id(),
597                model_provider = model.provider_id().to_string(),
598                repository_url = example.base.url.clone(),
599                repository_revision = example.base.revision.clone(),
600                diagnostics_before = run_output.diagnostics_before,
601                diagnostics_after = run_output.diagnostics_after,
602                commit_id = commit_id
603            );
604        }
605    }
606
607    judge_result
608}
609
/// Prints `header` centered in a 40-column banner framed above and below by rows of `=`,
/// with a blank line before and after the banner.
fn print_header(header: &str) {
    const RULE: &str = "========================================";
    println!("\n{RULE}");
    println!("{header:^40}");
    println!("{RULE}\n");
}