eval.rs

  1mod example;
  2mod ids;
  3
  4use client::{Client, ProxySettings, UserStore};
  5pub(crate) use example::*;
  6use telemetry;
  7
  8use ::fs::RealFs;
  9use anyhow::{Result, anyhow};
 10use clap::Parser;
 11use extension::ExtensionHostProxy;
 12use futures::{StreamExt, future};
 13use gpui::http_client::{Uri, read_proxy_from_env};
 14use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 15use gpui_tokio::Tokio;
 16use language::LanguageRegistry;
 17use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 18use node_runtime::{NodeBinaryOptions, NodeRuntime};
 19use project::Project;
 20use project::project_settings::ProjectSettings;
 21use prompt_store::PromptBuilder;
 22use release_channel::AppVersion;
 23use reqwest_client::ReqwestClient;
 24use settings::{Settings, SettingsStore};
 25use std::collections::HashSet;
 26use std::path::{Path, PathBuf};
 27use std::sync::Arc;
 28use std::usize;
 29use util::ResultExt as _;
 30
/// Root directory for eval output, relative to the current working directory. Each eval
/// invocation creates its own timestamped subdirectory here.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 32
 33#[derive(Parser, Debug)]
 34#[command(name = "eval", disable_version_flag = true)]
 35struct Args {
 36    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 37    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 38    examples: Vec<String>,
 39    /// Model to use (default: "claude-3-7-sonnet-latest")
 40    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 41    model: String,
 42    #[arg(long, value_delimiter = ',')]
 43    languages: Option<Vec<String>>,
 44    /// How many times to run each example. Note that this is currently not very efficient as N
 45    /// worktrees will be created for the examples.
 46    #[arg(long, default_value = "1")]
 47    repetitions: u32,
 48    /// How many times to run the judge on each example run.
 49    #[arg(long, default_value = "3")]
 50    judge_repetitions: u32,
 51    /// Maximum number of examples to run concurrently.
 52    #[arg(long, default_value = "10")]
 53    concurrency: usize,
 54}
 55
 56fn main() {
 57    env_logger::init();
 58
 59    let args = Args::parse();
 60    let all_available_examples = list_all_examples().unwrap();
 61    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 62
 63    let example_paths = all_available_examples
 64        .iter()
 65        .filter_map(|example_path| {
 66            let name = example_path.file_name()?.to_string_lossy();
 67            if args.examples.is_empty()
 68                || args
 69                    .examples
 70                    .iter()
 71                    .any(|name_substring| name.contains(name_substring))
 72            {
 73                Some(example_path.clone())
 74            } else {
 75                None
 76            }
 77        })
 78        .collect::<Vec<_>>();
 79
 80    let http_client = Arc::new(ReqwestClient::new());
 81    let app = Application::headless().with_http_client(http_client.clone());
 82
 83    app.run(move |cx| {
 84        let app_state = init(cx);
 85
 86        let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 87        let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 88        let session_id = uuid::Uuid::new_v4().to_string();
 89
 90        app_state
 91            .client
 92            .telemetry()
 93            .start(system_id, installation_id, session_id, cx);
 94
 95        let model_registry = LanguageModelRegistry::read_global(cx);
 96        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
 97        let model_provider_id = model.provider_id();
 98        let model_provider = model_registry.provider(&model_provider_id).unwrap();
 99
100        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
101            registry.set_default_model(
102                Some(ConfiguredModel {
103                    provider: model_provider.clone(),
104                    model: model.clone(),
105                }),
106                cx,
107            );
108        });
109
110        let authenticate_task = model_provider.authenticate(cx);
111
112        cx.spawn(async move |cx| {
113            authenticate_task.await.unwrap();
114
115            std::fs::create_dir_all(REPOS_DIR)?;
116            std::fs::create_dir_all(WORKTREES_DIR)?;
117
118            let run_dir = Path::new(RUNS_DIR).join(format!(
119                "{}",
120                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
121            ));
122            std::fs::create_dir_all(&run_dir)?;
123
124            let mut examples = Vec::new();
125
126            const COLORS: [&str; 12] = [
127                "\x1b[31m", // Red
128                "\x1b[32m", // Green
129                "\x1b[33m", // Yellow
130                "\x1b[34m", // Blue
131                "\x1b[35m", // Magenta
132                "\x1b[36m", // Cyan
133                "\x1b[91m", // Bright Red
134                "\x1b[92m", // Bright Green
135                "\x1b[93m", // Bright Yellow
136                "\x1b[94m", // Bright Blue
137                "\x1b[95m", // Bright Magenta
138                "\x1b[96m", // Bright Cyan
139            ];
140
141            let mut max_name_width = 0;
142            let mut skipped = Vec::new();
143
144            for example_path in &example_paths {
145                let example = Example::load_from_directory(example_path, &run_dir)?;
146
147                if !example
148                    .base
149                    .language_extension
150                    .as_ref()
151                    .map_or(false, |lang| languages.contains(lang))
152                {
153                    skipped.push(example.name);
154                    continue;
155                }
156
157                // TODO: This creates a worktree per repetition. Ideally these examples should
158                // either be run sequentially on the same worktree, or reuse worktrees when there
159                // are more examples to run than the concurrency limit.
160                for repetition_number in 0..args.repetitions {
161                    let mut example = example.clone();
162                    example.set_repetition_number(repetition_number);
163
164                    let name_len = example.name.len();
165                    if name_len > max_name_width {
166                        max_name_width = example.name.len();
167                    }
168
169                    examples.push(example);
170                }
171            }
172
173            println!("Skipped examples: {}\n", skipped.join(", "));
174
175            if examples.is_empty() {
176                eprintln!("Filter matched no examples");
177                return cx.update(|cx| cx.quit());
178            }
179
180            let mut repo_urls = HashSet::new();
181            let mut clone_tasks = Vec::new();
182
183            for (i, example) in examples.iter_mut().enumerate() {
184                let color = COLORS[i % COLORS.len()].to_string();
185                example.set_log_prefix_style(&color, max_name_width);
186
187                println!(
188                    "{}Logging to: {}",
189                    example.log_prefix,
190                    example.example_output_directory().display()
191                );
192
193                let repo_url = example.base.url.clone();
194                if repo_urls.insert(repo_url.clone()) {
195                    let repo_path = repo_path_for_url(&repo_url);
196
197                    if !repo_path.join(".git").is_dir() {
198                        println!(
199                            "{:<width$} < {}",
200                            "↓ Cloning",
201                            repo_url,
202                            width = max_name_width
203                        );
204
205                        let git_task = cx.spawn(async move |_cx| {
206                            std::fs::create_dir_all(&repo_path)?;
207                            run_git(&repo_path, &["init"]).await?;
208                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
209                        });
210
211                        clone_tasks.push(git_task);
212                    } else {
213                        println!(
214                            "{:<width$}  < {}",
215                            "✔︎ Already cloned",
216                            repo_url,
217                            width = max_name_width
218                        );
219
220                        let actual_origin =
221                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
222                        if actual_origin != repo_url {
223                            return Err(anyhow!(
224                                "remote origin {} does not match expected origin {}",
225                                actual_origin,
226                                repo_url,
227                            ));
228                        }
229                    }
230                }
231            }
232
233            future::join_all(clone_tasks).await;
234
235            for example in examples.iter_mut() {
236                example.setup().await?;
237            }
238
239            let judge_repetitions = args.judge_repetitions;
240            let concurrency = args.concurrency;
241
242            let tasks = examples.iter().map(|example| {
243                let app_state = app_state.clone();
244                let model = model.clone();
245                let example = example.clone();
246                cx.spawn(async move |cx| {
247                    let result =
248                        run_example(&example, model, app_state, judge_repetitions, cx).await;
249                    (result, example)
250                })
251            });
252
253            let results = futures::stream::iter(tasks)
254                .buffer_unordered(concurrency)
255                .collect::<Vec<_>>()
256                .await;
257
258            println!("\n\n");
259            println!("========================================");
260            println!("              EVAL RESULTS              ");
261            println!("========================================");
262            println!("");
263
264            let mut diff_scores = Vec::new();
265            let mut thread_scores = Vec::new();
266            let mut error_count = 0;
267
268            for (result, example) in results {
269                match result {
270                    Err(err) => {
271                        println!("💥 {}{:?}", example.log_prefix, err);
272                        error_count += 1;
273                    }
274                    Ok(judge_results) => {
275                        for judge_result in judge_results {
276                            match judge_result {
277                                Ok(judge_output) => {
278                                    const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
279                                    let diff_score: u32 = judge_output.diff.score;
280                                    let score_index = (diff_score.min(5)) as usize;
281
282                                    println!(
283                                        "{} {}{} (Diff)",
284                                        SCORES[score_index],
285                                        example.log_prefix,
286                                        judge_output.diff.score,
287                                    );
288                                    diff_scores.push(judge_output.diff.score);
289
290                                    if let Some(thread) = judge_output.thread {
291                                        let process_score: u32 = thread.score;
292                                        let score_index = (process_score.min(5)) as usize;
293                                        println!(
294                                            "{} {}{} (Thread)",
295                                            SCORES[score_index], example.log_prefix, thread.score,
296                                        );
297                                        thread_scores.push(thread.score);
298                                    }
299                                }
300                                Err(err) => {
301                                    println!("💥 {}{:?}", example.log_prefix, err);
302                                }
303                            }
304                        }
305                    }
306                }
307                println!(
308                    "{}    > {}",
309                    " ".repeat(max_name_width),
310                    example.example_output_directory().display()
311                );
312            }
313
314            let diff_score_count = diff_scores.len();
315            let average_diff_score = diff_scores
316                .into_iter()
317                .map(|score| score as f32)
318                .sum::<f32>()
319                / (diff_score_count as f32);
320
321            if error_count > 0 {
322                println!("\n{error_count} examples failed to run!");
323            }
324
325            if diff_score_count > 0 {
326                println!("\nAverage code diff score: {average_diff_score}");
327            }
328
329            let thread_score_count = thread_scores.len();
330
331            // We might have gotten no thread scores if we weren't asked to judge the thread.
332            if thread_score_count > 0 {
333                let average_thread_score = thread_scores
334                    .into_iter()
335                    .map(|score| score as f32)
336                    .sum::<f32>()
337                    / (thread_score_count as f32);
338
339                if diff_score_count > 0 {
340                    println!("\nAverage thread score: {average_thread_score}");
341                }
342            }
343
344            std::thread::sleep(std::time::Duration::from_secs(2));
345
346            app_state.client.telemetry().flush_events();
347
348            cx.update(|cx| cx.quit())
349        })
350        .detach_and_log_err(cx);
351    });
352}
353
354async fn run_example(
355    example: &Example,
356    model: Arc<dyn LanguageModel>,
357    app_state: Arc<AgentAppState>,
358    judge_repetitions: u32,
359    cx: &mut AsyncApp,
360) -> Result<Vec<Result<JudgeOutput>>> {
361    let run_output = cx
362        .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
363        .await?;
364
365    let judge_tasks = (0..judge_repetitions)
366        .map(|round| run_judge_repetition(example.clone(), model.clone(), &run_output, round, cx));
367
368    let results = future::join_all(judge_tasks).await;
369
370    app_state.client.telemetry().flush_events();
371
372    Ok(results)
373}
374
375fn list_all_examples() -> Result<Vec<PathBuf>> {
376    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
377    let entries = std::fs::read_dir(path).unwrap();
378    let mut result_paths = Vec::new();
379    for entry in entries {
380        let entry = entry?;
381        let path = entry.path();
382        if path.is_dir() {
383            result_paths.push(path);
384        }
385    }
386    Ok(result_paths)
387}
388
389/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
390pub struct AgentAppState {
391    pub languages: Arc<LanguageRegistry>,
392    pub client: Arc<Client>,
393    pub user_store: Entity<UserStore>,
394    pub fs: Arc<dyn fs::Fs>,
395    pub node_runtime: NodeRuntime,
396
397    // Additional fields not present in `workspace::AppState`.
398    pub prompt_builder: Arc<PromptBuilder>,
399}
400
/// Sets up the global state the eval runner needs inside a headless app and returns it as an
/// [`AgentAppState`].
///
/// Statement order matters: later steps read globals that earlier steps install (for example
/// `AppVersion` after `release_channel::init`, `ProxySettings` after `client::init_settings`,
/// and `ExtensionHostProxy` after `extension::init`). Reorder with care.
///
/// # Panics
///
/// Panics if the default settings or `runner_settings.json` fail to load, or if the HTTP
/// client cannot be built.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // Register the release channel with a default `SemanticVersion`; presumably this is the
    // version `AppVersion::global` reads for the User-Agent below.
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install a settings store seeded with `settings::default_settings()`. The runner's own
    // settings are applied at the end of this function.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the `proxy` setting. If it is unset or does not parse as a URI, fall back to the
    // proxy read from the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // The reqwest client is built while the gpui_tokio runtime handle is entered;
        // presumably construction requires a Tokio context — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // NOTE(review): this replaces the proxy-aware client installed above with the `Client`'s
    // HTTP client; presumably `Client::production` wraps the app-level client — confirm.
    cx.set_http_client(client.http_client().clone());

    // No explicit git binary path is configured for the filesystem.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Recompute the Node.js binary options from `ProjectSettings::node` whenever the settings
    // store global changes, and publish them to `NodeRuntime` through a watch channel that
    // starts out as `None`. NOTE(review): the first value appears to be sent when
    // `set_user_settings` runs at the end of this function — confirm `observe_global` fires
    // for that update.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            // When a node `path` is configured (with `~` expanded), pass explicit node/npm
            // paths. npm defaults to `npm` in node's directory unless `npm_path` is set.
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    // Hard-coded for the eval runner rather than detected from the real stdout.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the runner-specific settings, embedded at compile time from
    // `runner_settings.json`, as user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
505
506pub fn find_model(
507    model_name: &str,
508    model_registry: &LanguageModelRegistry,
509    cx: &App,
510) -> anyhow::Result<Arc<dyn LanguageModel>> {
511    let model = model_registry
512        .available_models(cx)
513        .find(|model| model.id().0 == model_name);
514
515    let Some(model) = model else {
516        return Err(anyhow!(
517            "No language model named {} was available. Available models: {}",
518            model_name,
519            model_registry
520                .available_models(cx)
521                .map(|model| model.id().0.clone())
522                .collect::<Vec<_>>()
523                .join(", ")
524        ));
525    };
526
527    Ok(model)
528}
529
530pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
531    (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
532}
533
534pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
535    futures::executor::block_on(async {
536        get_current_commit_id(repo_path).await.unwrap_or_default()
537    })
538}
539
540async fn run_judge_repetition(
541    example: Example,
542    model: Arc<dyn LanguageModel>,
543    run_output: &RunOutput,
544    round: u32,
545    cx: &AsyncApp,
546) -> Result<JudgeOutput> {
547    let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
548
549    if let Ok(judge_output) = &judge_result {
550        let cohort_id = example
551            .run_directory_path
552            .file_name()
553            .map(|name| name.to_string_lossy().to_string())
554            .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
555
556        let path = std::path::Path::new(".");
557        let commit_id = get_current_commit_id(path).await.unwrap_or_default();
558
559        if let Some(thread) = &judge_output.thread {
560            telemetry::event!(
561                "Agent Eval Completed",
562                cohort_id = cohort_id,
563                example_name = example.name.clone(),
564                round = round,
565                diff_score = judge_output.diff.score,
566                diff_analysis = judge_output.diff.analysis,
567                thread_score = thread.score,
568                thread_analysis = thread.analysis,
569                tool_use_counts = run_output.tool_use_counts,
570                response_count = run_output.response_count,
571                token_usage = run_output.token_usage,
572                model = model.telemetry_id(),
573                model_provider = model.provider_id().to_string(),
574                repository_url = example.base.url.clone(),
575                repository_revision = example.base.revision.clone(),
576                diagnostics_before = run_output.diagnostics_before,
577                diagnostics_after = run_output.diagnostics_after,
578                commit_id = commit_id
579            );
580        } else {
581            telemetry::event!(
582                "Agent Eval Completed",
583                cohort_id = cohort_id,
584                example_name = example.name.clone(),
585                round = round,
586                diff_score = judge_output.diff.score,
587                diff_analysis = judge_output.diff.analysis,
588                tool_use_counts = run_output.tool_use_counts,
589                response_count = run_output.response_count,
590                token_usage = run_output.token_usage,
591                model = model.telemetry_id(),
592                model_provider = model.provider_id().to_string(),
593                repository_url = example.base.url.clone(),
594                repository_revision = example.base.revision.clone(),
595                diagnostics_before = run_output.diagnostics_before,
596                diagnostics_after = run_output.diagnostics_after,
597                commit_id = commit_id
598            );
599        }
600    }
601
602    judge_result
603}