eval.rs

  1mod example;
  2mod ids;
  3mod tool_metrics;
  4
  5pub(crate) use example::*;
  6pub(crate) use tool_metrics::*;
  7
  8use ::fs::RealFs;
  9use anyhow::{Result, anyhow};
 10use clap::Parser;
 11use client::{Client, ProxySettings, UserStore};
 12use collections::HashSet;
 13use extension::ExtensionHostProxy;
 14use futures::{StreamExt, future};
 15use gpui::http_client::{Uri, read_proxy_from_env};
 16use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 17use gpui_tokio::Tokio;
 18use language::LanguageRegistry;
 19use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 20use node_runtime::{NodeBinaryOptions, NodeRuntime};
 21use project::Project;
 22use project::project_settings::ProjectSettings;
 23use prompt_store::PromptBuilder;
 24use release_channel::AppVersion;
 25use reqwest_client::ReqwestClient;
 26use settings::{Settings, SettingsStore};
 27use std::path::{Path, PathBuf};
 28use std::sync::Arc;
 29use std::usize;
 30use util::ResultExt as _;
 31
 /// Directory, relative to the working directory, under which each eval invocation creates its
/// own timestamped run directory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 33
 34#[derive(Parser, Debug)]
 35#[command(name = "eval", disable_version_flag = true)]
 36struct Args {
 37    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 38    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 39    examples: Vec<String>,
 40    /// Model to use (default: "claude-3-7-sonnet-latest")
 41    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 42    model: String,
 43    #[arg(long, value_delimiter = ',')]
 44    languages: Option<Vec<String>>,
 45    /// How many times to run each example. Note that this is currently not very efficient as N
 46    /// worktrees will be created for the examples.
 47    #[arg(long, default_value = "1")]
 48    repetitions: u32,
 49    /// How many times to run the judge on each example run.
 50    #[arg(long, default_value = "3")]
 51    judge_repetitions: u32,
 52    /// Maximum number of examples to run concurrently.
 53    #[arg(long, default_value = "10")]
 54    concurrency: usize,
 55}
 56
 57fn main() {
 58    env_logger::init();
 59
 60    let args = Args::parse();
 61    let all_available_examples = list_all_examples().unwrap();
 62    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 63
 64    let example_paths = all_available_examples
 65        .iter()
 66        .filter_map(|example_path| {
 67            let name = example_path.file_name()?.to_string_lossy();
 68            if args.examples.is_empty()
 69                || args
 70                    .examples
 71                    .iter()
 72                    .any(|name_substring| name.contains(name_substring))
 73            {
 74                Some(example_path.clone())
 75            } else {
 76                None
 77            }
 78        })
 79        .collect::<Vec<_>>();
 80
 81    let http_client = Arc::new(ReqwestClient::new());
 82    let app = Application::headless().with_http_client(http_client.clone());
 83
 84    app.run(move |cx| {
 85        let app_state = init(cx);
 86
 87        let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 88        let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 89        let session_id = uuid::Uuid::new_v4().to_string();
 90
 91        app_state
 92            .client
 93            .telemetry()
 94            .start(system_id, installation_id, session_id, cx);
 95
 96        let mut cumulative_tool_metrics = ToolMetrics::default();
 97
 98        let model_registry = LanguageModelRegistry::read_global(cx);
 99        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
100        let model_provider_id = model.provider_id();
101        let model_provider = model_registry.provider(&model_provider_id).unwrap();
102
103        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
104            registry.set_default_model(
105                Some(ConfiguredModel {
106                    provider: model_provider.clone(),
107                    model: model.clone(),
108                }),
109                cx,
110            );
111        });
112
113        let authenticate_task = model_provider.authenticate(cx);
114
115        cx.spawn(async move |cx| {
116            authenticate_task.await.unwrap();
117
118            std::fs::create_dir_all(REPOS_DIR)?;
119            std::fs::create_dir_all(WORKTREES_DIR)?;
120
121            let run_dir = Path::new(RUNS_DIR).join(format!(
122                "{}",
123                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
124            ));
125            std::fs::create_dir_all(&run_dir)?;
126
127            let mut examples = Vec::new();
128
129            const COLORS: [&str; 12] = [
130                "\x1b[31m", // Red
131                "\x1b[32m", // Green
132                "\x1b[33m", // Yellow
133                "\x1b[34m", // Blue
134                "\x1b[35m", // Magenta
135                "\x1b[36m", // Cyan
136                "\x1b[91m", // Bright Red
137                "\x1b[92m", // Bright Green
138                "\x1b[93m", // Bright Yellow
139                "\x1b[94m", // Bright Blue
140                "\x1b[95m", // Bright Magenta
141                "\x1b[96m", // Bright Cyan
142            ];
143
144            let mut max_name_width = 0;
145            let mut skipped = Vec::new();
146
147            for example_path in &example_paths {
148                let example = Example::load_from_directory(example_path, &run_dir)?;
149
150                if !example
151                    .base
152                    .language_extension
153                    .as_ref()
154                    .map_or(false, |lang| languages.contains(lang))
155                {
156                    skipped.push(example.name);
157                    continue;
158                }
159
160                // TODO: This creates a worktree per repetition. Ideally these examples should
161                // either be run sequentially on the same worktree, or reuse worktrees when there
162                // are more examples to run than the concurrency limit.
163                for repetition_number in 0..args.repetitions {
164                    let mut example = example.clone();
165                    example.set_repetition_number(repetition_number);
166
167                    let name_len = example.name.len();
168                    if name_len > max_name_width {
169                        max_name_width = example.name.len();
170                    }
171
172                    examples.push(example);
173                }
174            }
175
176            println!("Skipped examples: {}\n", skipped.join(", "));
177
178            if examples.is_empty() {
179                eprintln!("Filter matched no examples");
180                return cx.update(|cx| cx.quit());
181            }
182
183            let mut repo_urls = HashSet::default();
184            let mut clone_tasks = Vec::new();
185
186            for (i, example) in examples.iter_mut().enumerate() {
187                let color = COLORS[i % COLORS.len()].to_string();
188                example.set_log_prefix_style(&color, max_name_width);
189
190                println!(
191                    "{}Logging to: {}",
192                    example.log_prefix,
193                    example.example_output_directory().display()
194                );
195
196                let repo_url = example.base.url.clone();
197                if repo_urls.insert(repo_url.clone()) {
198                    let repo_path = repo_path_for_url(&repo_url);
199
200                    if !repo_path.join(".git").is_dir() {
201                        println!(
202                            "{:<width$} < {}",
203                            "↓ Cloning",
204                            repo_url,
205                            width = max_name_width
206                        );
207
208                        let git_task = cx.spawn(async move |_cx| {
209                            std::fs::create_dir_all(&repo_path)?;
210                            run_git(&repo_path, &["init"]).await?;
211                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
212                        });
213
214                        clone_tasks.push(git_task);
215                    } else {
216                        println!(
217                            "{:<width$}  < {}",
218                            "✔︎ Already cloned",
219                            repo_url,
220                            width = max_name_width
221                        );
222
223                        let actual_origin =
224                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
225                        if actual_origin != repo_url {
226                            return Err(anyhow!(
227                                "remote origin {} does not match expected origin {}",
228                                actual_origin,
229                                repo_url,
230                            ));
231                        }
232                    }
233                }
234            }
235
236            future::join_all(clone_tasks).await;
237
238            for example in examples.iter_mut() {
239                example.setup().await?;
240            }
241
242            let judge_repetitions = args.judge_repetitions;
243            let concurrency = args.concurrency;
244
245            let tasks = examples.iter().map(|example| {
246                let app_state = app_state.clone();
247                let model = model.clone();
248                let example = example.clone();
249                cx.spawn(async move |cx| {
250                    let result = async {
251                        let run_output = cx
252                            .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
253                            .await?;
254                        let judge_tasks = (0..judge_repetitions).map(|round| {
255                            run_judge_repetition(
256                                example.clone(),
257                                model.clone(),
258                                &run_output,
259                                round,
260                                cx,
261                            )
262                        });
263                        let judge_outputs = future::join_all(judge_tasks).await;
264                        anyhow::Ok((run_output, judge_outputs))
265                    }
266                    .await;
267                    (example, result)
268                })
269            });
270
271            let results = futures::stream::iter(tasks)
272                .buffer_unordered(concurrency)
273                .collect::<Vec<_>>()
274                .await;
275
276            println!("\n\n");
277            print_header("EVAL RESULTS");
278
279            let mut diff_scores = Vec::new();
280            let mut thread_scores = Vec::new();
281            let mut error_count = 0;
282
283            for (example, result) in results {
284                print_header(&example.name);
285
286                match result {
287                    Err(err) => {
288                        println!("💥 {}{:?}", example.log_prefix, err);
289                        error_count += 1;
290                    }
291                    Ok((run_output, judge_results)) => {
292                        cumulative_tool_metrics.merge(&run_output.tool_metrics);
293
294                        println!("┌───────┬──────┬────────┐");
295                        println!("│ Judge │ Diff │ Thread │");
296                        println!("├───────┼──────┼────────┤");
297
298                        for (i, judge_result) in judge_results.iter().enumerate() {
299                            match judge_result {
300                                Ok(judge_output) => {
301                                    let diff_score = judge_output.diff.score;
302                                    diff_scores.push(diff_score);
303
304                                    let thread_display = if let Some(thread) = &judge_output.thread
305                                    {
306                                        let thread_score = thread.score;
307                                        thread_scores.push(thread_score);
308                                        format!("{}", thread_score)
309                                    } else {
310                                        "N/A".to_string()
311                                    };
312
313                                    println!(
314                                        "|{:^7}{:^6}{:^8}",
315                                        i + 1,
316                                        diff_score,
317                                        thread_display
318                                    );
319                                }
320                                Err(err) => {
321                                    println!("|{:^7}{:^6}{:^8}{:?}", i + 1, "N/A", "N/A", err);
322                                }
323                            }
324                        }
325
326                        println!("└───────┴──────┴────────┘");
327
328                        println!("{}", run_output.tool_metrics);
329                    }
330                }
331                println!(
332                    "{}    > {}",
333                    " ".repeat(max_name_width),
334                    example.example_output_directory().display()
335                );
336            }
337
338            let diff_score_count = diff_scores.len();
339            let average_diff_score = diff_scores
340                .into_iter()
341                .map(|score| score as f32)
342                .sum::<f32>()
343                / (diff_score_count as f32);
344
345            if error_count > 0 {
346                println!("\n{error_count} examples failed to run!");
347            }
348
349            if diff_score_count > 0 {
350                println!("\nAverage code diff score: {average_diff_score}");
351            }
352
353            let thread_score_count = thread_scores.len();
354
355            // We might have gotten no thread scores if we weren't asked to judge the thread.
356            if thread_score_count > 0 {
357                let average_thread_score = thread_scores
358                    .into_iter()
359                    .map(|score| score as f32)
360                    .sum::<f32>()
361                    / (thread_score_count as f32);
362
363                if diff_score_count > 0 {
364                    println!("\nAverage thread score: {average_thread_score}");
365                }
366            }
367
368            print_header("CUMULATIVE TOOL METRICS");
369            println!("{}", cumulative_tool_metrics);
370
371            std::thread::sleep(std::time::Duration::from_secs(2));
372
373            app_state.client.telemetry().flush_events();
374
375            cx.update(|cx| cx.quit())
376        })
377        .detach_and_log_err(cx);
378    });
379}
380
381fn list_all_examples() -> Result<Vec<PathBuf>> {
382    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
383    let entries = std::fs::read_dir(path).unwrap();
384    let mut result_paths = Vec::new();
385    for entry in entries {
386        let entry = entry?;
387        let path = entry.path();
388        if path.is_dir() {
389            result_paths.push(path);
390        }
391    }
392    Ok(result_paths)
393}
394
395/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
396pub struct AgentAppState {
397    pub languages: Arc<LanguageRegistry>,
398    pub client: Arc<Client>,
399    pub user_store: Entity<UserStore>,
400    pub fs: Arc<dyn fs::Fs>,
401    pub node_runtime: NodeRuntime,
402
403    // Additional fields not present in `workspace::AppState`.
404    pub prompt_builder: Arc<PromptBuilder>,
405}
406
407pub fn init(cx: &mut App) -> Arc<AgentAppState> {
408    release_channel::init(SemanticVersion::default(), cx);
409    gpui_tokio::init(cx);
410
411    let mut settings_store = SettingsStore::new(cx);
412    settings_store
413        .set_default_settings(settings::default_settings().as_ref(), cx)
414        .unwrap();
415    cx.set_global(settings_store);
416    client::init_settings(cx);
417
418    // Set User-Agent so we can download language servers from GitHub
419    let user_agent = format!(
420        "Zed/{} ({}; {})",
421        AppVersion::global(cx),
422        std::env::consts::OS,
423        std::env::consts::ARCH
424    );
425    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
426    let proxy_url = proxy_str
427        .as_ref()
428        .and_then(|input| input.parse::<Uri>().ok())
429        .or_else(read_proxy_from_env);
430    let http = {
431        let _guard = Tokio::handle(cx).enter();
432
433        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
434            .expect("could not start HTTP client")
435    };
436    cx.set_http_client(Arc::new(http));
437
438    Project::init_settings(cx);
439
440    let client = Client::production(cx);
441    cx.set_http_client(client.http_client().clone());
442
443    let git_binary_path = None;
444    let fs = Arc::new(RealFs::new(
445        git_binary_path,
446        cx.background_executor().clone(),
447    ));
448
449    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
450    languages.set_language_server_download_dir(paths::languages_dir().clone());
451    let languages = Arc::new(languages);
452
453    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
454
455    extension::init(cx);
456
457    let (tx, rx) = async_watch::channel(None);
458    cx.observe_global::<SettingsStore>(move |cx| {
459        let settings = &ProjectSettings::get_global(cx).node;
460        let options = NodeBinaryOptions {
461            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
462            allow_binary_download: true,
463            use_paths: settings.path.as_ref().map(|node_path| {
464                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
465                let npm_path = settings
466                    .npm_path
467                    .as_ref()
468                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
469                (
470                    node_path.clone(),
471                    npm_path.unwrap_or_else(|| {
472                        let base_path = PathBuf::new();
473                        node_path.parent().unwrap_or(&base_path).join("npm")
474                    }),
475                )
476            }),
477        };
478        tx.send(Some(options)).log_err();
479    })
480    .detach();
481    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
482
483    let extension_host_proxy = ExtensionHostProxy::global(cx);
484
485    language::init(cx);
486    language_extension::init(extension_host_proxy.clone(), languages.clone());
487    language_model::init(client.clone(), cx);
488    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
489    languages::init(languages.clone(), node_runtime.clone(), cx);
490    assistant_tools::init(client.http_client().clone(), cx);
491    context_server::init(cx);
492    prompt_store::init(cx);
493    let stdout_is_a_pty = false;
494    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
495    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
496
497    SettingsStore::update_global(cx, |store, cx| {
498        store.set_user_settings(include_str!("../runner_settings.json"), cx)
499    })
500    .unwrap();
501
502    Arc::new(AgentAppState {
503        languages,
504        client,
505        user_store,
506        fs,
507        node_runtime,
508        prompt_builder,
509    })
510}
511
512pub fn find_model(
513    model_name: &str,
514    model_registry: &LanguageModelRegistry,
515    cx: &App,
516) -> anyhow::Result<Arc<dyn LanguageModel>> {
517    let model = model_registry
518        .available_models(cx)
519        .find(|model| model.id().0 == model_name);
520
521    let Some(model) = model else {
522        return Err(anyhow!(
523            "No language model named {} was available. Available models: {}",
524            model_name,
525            model_registry
526                .available_models(cx)
527                .map(|model| model.id().0.clone())
528                .collect::<Vec<_>>()
529                .join(", ")
530        ));
531    };
532
533    Ok(model)
534}
535
536pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
537    (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
538}
539
540pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
541    futures::executor::block_on(async {
542        get_current_commit_id(repo_path).await.unwrap_or_default()
543    })
544}
545
546async fn run_judge_repetition(
547    example: Example,
548    model: Arc<dyn LanguageModel>,
549    run_output: &RunOutput,
550    round: u32,
551    cx: &AsyncApp,
552) -> Result<JudgeOutput> {
553    let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
554
555    if let Ok(judge_output) = &judge_result {
556        let cohort_id = example
557            .run_directory_path
558            .file_name()
559            .map(|name| name.to_string_lossy().to_string())
560            .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
561
562        let path = std::path::Path::new(".");
563        let commit_id = get_current_commit_id(path).await.unwrap_or_default();
564
565        if let Some(thread) = &judge_output.thread {
566            telemetry::event!(
567                "Agent Eval Completed",
568                cohort_id = cohort_id,
569                example_name = example.name.clone(),
570                round = round,
571                diff_score = judge_output.diff.score,
572                diff_analysis = judge_output.diff.analysis,
573                thread_score = thread.score,
574                thread_analysis = thread.analysis,
575                tool_metrics = run_output.tool_metrics,
576                response_count = run_output.response_count,
577                token_usage = run_output.token_usage,
578                model = model.telemetry_id(),
579                model_provider = model.provider_id().to_string(),
580                repository_url = example.base.url.clone(),
581                repository_revision = example.base.revision.clone(),
582                diagnostics_before = run_output.diagnostics_before,
583                diagnostics_after = run_output.diagnostics_after,
584                commit_id = commit_id
585            );
586        } else {
587            telemetry::event!(
588                "Agent Eval Completed",
589                cohort_id = cohort_id,
590                example_name = example.name.clone(),
591                round = round,
592                diff_score = judge_output.diff.score,
593                diff_analysis = judge_output.diff.analysis,
594                tool_metrics = run_output.tool_metrics,
595                response_count = run_output.response_count,
596                token_usage = run_output.token_usage,
597                model = model.telemetry_id(),
598                model_provider = model.provider_id().to_string(),
599                repository_url = example.base.url.clone(),
600                repository_revision = example.base.revision.clone(),
601                diagnostics_before = run_output.diagnostics_before,
602                diagnostics_after = run_output.diagnostics_after,
603                commit_id = commit_id
604            );
605        }
606    }
607
608    judge_result
609}
610
/// Prints `header` centered in a 40-column banner framed by `=` rules, with a blank line before
/// and after the banner.
fn print_header(header: &str) {
    const RULE: &str = "========================================";

    println!("\n{RULE}");
    println!("{header:^40}");
    println!("{RULE}\n");
}