eval.rs

  1mod example;
  2mod ids;
  3
  4use client::{Client, ProxySettings, UserStore};
  5pub(crate) use example::*;
  6use telemetry;
  7
  8use ::fs::RealFs;
  9use anyhow::{Result, anyhow};
 10use clap::Parser;
 11use extension::ExtensionHostProxy;
 12use futures::future;
 13use futures::stream::StreamExt;
 14use gpui::http_client::{Uri, read_proxy_from_env};
 15use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
 16use gpui_tokio::Tokio;
 17use language::LanguageRegistry;
 18use language_model::{
 19    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 20};
 21use node_runtime::{NodeBinaryOptions, NodeRuntime};
 22use project::Project;
 23use project::project_settings::ProjectSettings;
 24use prompt_store::PromptBuilder;
 25use release_channel::AppVersion;
 26use reqwest_client::ReqwestClient;
 27use settings::{Settings, SettingsStore};
 28use std::collections::HashSet;
 29use std::path::{Path, PathBuf};
 30use std::sync::Arc;
 31use std::usize;
 32use util::ResultExt as _;
 33
/// Directory (relative to the working directory) under which `main` creates one
/// timestamped sub-directory per eval invocation.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 35
 36#[derive(Parser, Debug)]
 37#[command(name = "eval", disable_version_flag = true)]
 38struct Args {
 39    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 40    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 41    examples: Vec<String>,
 42    /// Model to use (default: "claude-3-7-sonnet-latest")
 43    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 44    model: String,
 45    #[arg(long, value_delimiter = ',')]
 46    languages: Option<Vec<String>>,
 47    /// How many times to run each example. Note that this is currently not very efficient as N
 48    /// worktrees will be created for the examples.
 49    #[arg(long, default_value = "1")]
 50    repetitions: u32,
 51    /// How many times to run the judge on each example run.
 52    #[arg(long, default_value = "3")]
 53    judge_repetitions: u32,
 54    /// Maximum number of examples to run concurrently.
 55    #[arg(long, default_value = "10")]
 56    concurrency: usize,
 57}
 58
 59fn main() {
 60    env_logger::init();
 61
 62    let args = Args::parse();
 63    let all_available_examples = list_all_examples().unwrap();
 64    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 65
 66    let example_paths = all_available_examples
 67        .iter()
 68        .filter_map(|example_path| {
 69            let name = example_path.file_name()?.to_string_lossy();
 70            if args.examples.is_empty()
 71                || args
 72                    .examples
 73                    .iter()
 74                    .any(|name_substring| name.contains(name_substring))
 75            {
 76                Some(example_path.clone())
 77            } else {
 78                None
 79            }
 80        })
 81        .collect::<Vec<_>>();
 82
 83    let http_client = Arc::new(ReqwestClient::new());
 84    let app = Application::headless().with_http_client(http_client.clone());
 85
 86    app.run(move |cx| {
 87        let app_state = init(cx);
 88
 89        let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 90        let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 91        let session_id = uuid::Uuid::new_v4().to_string();
 92
 93        app_state
 94            .client
 95            .telemetry()
 96            .start(system_id, installation_id, session_id, cx);
 97
 98        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 99
100        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
101            registry.set_default_model(Some(model.clone()), cx);
102        });
103
104        let model_provider_id = model.provider_id();
105
106        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
107
108        cx.spawn(async move |cx| {
109            authenticate.await.unwrap();
110
111            std::fs::create_dir_all(REPOS_DIR)?;
112            std::fs::create_dir_all(WORKTREES_DIR)?;
113
114            let run_dir = Path::new(RUNS_DIR).join(format!(
115                "{}",
116                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
117            ));
118            std::fs::create_dir_all(&run_dir)?;
119
120            let mut examples = Vec::new();
121
122            const COLORS: [&str; 12] = [
123                "\x1b[31m", // Red
124                "\x1b[32m", // Green
125                "\x1b[33m", // Yellow
126                "\x1b[34m", // Blue
127                "\x1b[35m", // Magenta
128                "\x1b[36m", // Cyan
129                "\x1b[91m", // Bright Red
130                "\x1b[92m", // Bright Green
131                "\x1b[93m", // Bright Yellow
132                "\x1b[94m", // Bright Blue
133                "\x1b[95m", // Bright Magenta
134                "\x1b[96m", // Bright Cyan
135            ];
136
137            let mut max_name_width = 0;
138            let mut skipped = Vec::new();
139
140            for example_path in &example_paths {
141                let example = Example::load_from_directory(example_path, &run_dir)?;
142
143                if !example
144                    .base
145                    .language_extension
146                    .as_ref()
147                    .map_or(false, |lang| languages.contains(lang))
148                {
149                    skipped.push(example.name);
150                    continue;
151                }
152
153                // TODO: This creates a worktree per repetition. Ideally these examples should
154                // either be run sequentially on the same worktree, or reuse worktrees when there
155                // are more examples to run than the concurrency limit.
156                for repetition_number in 0..args.repetitions {
157                    let mut example = example.clone();
158                    example.set_repetition_number(repetition_number);
159
160                    let name_len = example.name.len();
161                    if name_len > max_name_width {
162                        max_name_width = example.name.len();
163                    }
164
165                    examples.push(example);
166                }
167            }
168
169            println!("Skipped examples: {}\n", skipped.join(", "));
170
171            if examples.is_empty() {
172                eprintln!("Filter matched no examples");
173                return cx.update(|cx| cx.quit());
174            }
175
176            let mut repo_urls = HashSet::new();
177            let mut clone_tasks = Vec::new();
178
179            for (i, example) in examples.iter_mut().enumerate() {
180                let color = COLORS[i % COLORS.len()].to_string();
181                example.set_log_prefix_style(&color, max_name_width);
182
183                println!(
184                    "{}Logging to: {}",
185                    example.log_prefix,
186                    example.output_file_path.display()
187                );
188
189                let repo_url = example.base.url.clone();
190                if repo_urls.insert(repo_url.clone()) {
191                    let repo_path = repo_path_for_url(&repo_url);
192
193                    if !repo_path.join(".git").is_dir() {
194                        println!(
195                            "{:<width$}  < {}",
196                            "↓ Cloning",
197                            repo_url,
198                            width = max_name_width
199                        );
200
201                        let git_task = cx.spawn(async move |_cx| {
202                            std::fs::create_dir_all(&repo_path)?;
203                            run_git(&repo_path, &["init"]).await?;
204                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
205                        });
206
207                        clone_tasks.push(git_task);
208                    } else {
209                        println!(
210                            "{:<width$}  < {}",
211                            "✔︎ Already cloned",
212                            repo_url,
213                            width = max_name_width
214                        );
215
216                        let actual_origin =
217                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
218                        if actual_origin != repo_url {
219                            return Err(anyhow!(
220                                "remote origin {} does not match expected origin {}",
221                                actual_origin,
222                                repo_url,
223                            ));
224                        }
225                    }
226                }
227            }
228
229            future::join_all(clone_tasks).await;
230
231            for example in examples.iter_mut() {
232                example.setup().await?;
233            }
234
235            let judge_repetitions = args.judge_repetitions;
236            let concurrency = args.concurrency;
237
238            let tasks = examples
239                .into_iter()
240                .map(|example| {
241                    let app_state = app_state.clone();
242                    let model = model.clone();
243                    cx.spawn(async move |cx| {
244                        let result =
245                            run_example(&example, model, app_state, judge_repetitions, cx).await;
246                        (result, example)
247                    })
248                })
249                .collect::<Vec<_>>();
250
251            let results = futures::stream::iter(tasks)
252                .buffer_unordered(concurrency)
253                .collect::<Vec<(Result<Vec<Result<JudgeOutput>>>, Example)>>()
254                .await;
255
256            println!("\n\n");
257            println!("========================================");
258            println!("              EVAL RESULTS              ");
259            println!("========================================");
260            println!("");
261
262            let mut judge_scores = Vec::new();
263
264            for (result, example) in results {
265                match result {
266                    Err(err) => {
267                        println!("💥 {}{:?}", example.log_prefix, err);
268                    }
269                    Ok(judge_results) => {
270                        for judge_result in judge_results {
271                            match judge_result {
272                                Ok(judge_output) => {
273                                    const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
274                                    let score: u32 = judge_output.score;
275                                    let score_index = (score.min(5)) as usize;
276
277                                    println!(
278                                        "{} {}{}",
279                                        SCORES[score_index], example.log_prefix, judge_output.score,
280                                    );
281                                    judge_scores.push(judge_output.score);
282                                }
283                                Err(err) => {
284                                    println!("💥 {}{:?}", example.log_prefix, err);
285                                }
286                            }
287                        }
288                    }
289                }
290                println!(
291                    "{}    > {}",
292                    " ".repeat(max_name_width),
293                    example.output_file_path.display()
294                );
295            }
296
297            let score_count = judge_scores.len();
298            let average_score = judge_scores
299                .into_iter()
300                .map(|score| score as f32)
301                .sum::<f32>()
302                / (score_count as f32);
303            println!("\nAverage score: {average_score}");
304
305            std::thread::sleep(std::time::Duration::from_secs(2));
306
307            app_state.client.telemetry().flush_events();
308
309            cx.update(|cx| cx.quit())
310        })
311        .detach_and_log_err(cx);
312    });
313}
314
315async fn run_example(
316    example: &Example,
317    model: Arc<dyn LanguageModel>,
318    app_state: Arc<AgentAppState>,
319    judge_repetitions: u32,
320    cx: &mut AsyncApp,
321) -> Result<Vec<Result<JudgeOutput>>> {
322    let run_output = cx
323        .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
324        .await?;
325    let diff = example.repository_diff().await?;
326
327    // Run judge for each repetition
328    let mut results = Vec::new();
329    for round in 0..judge_repetitions {
330        let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
331
332        if let Ok(judge_output) = &judge_result {
333            let cohort_id = example
334                .output_file_path
335                .parent()
336                .and_then(|p| p.file_name())
337                .map(|name| name.to_string_lossy().to_string())
338                .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
339
340            let path = std::path::Path::new(".");
341            let commit_id = get_current_commit_id(path).await.unwrap_or_default();
342
343            telemetry::event!(
344                "Agent Eval Completed",
345                cohort_id = cohort_id,
346                example_name = example.name.clone(),
347                round = round,
348                score = judge_output.score,
349                analysis = judge_output.analysis,
350                tool_use_counts = run_output.tool_use_counts,
351                response_count = run_output.response_count,
352                token_usage = run_output.token_usage,
353                model = model.telemetry_id(),
354                model_provider = model.provider_id().to_string(),
355                repository_url = example.base.url.clone(),
356                repository_revision = example.base.revision.clone(),
357                diagnostics_summary = run_output.diagnostics,
358                commit_id = commit_id
359            );
360        }
361
362        results.push(judge_result);
363    }
364
365    app_state.client.telemetry().flush_events();
366
367    Ok(results)
368}
369
370fn list_all_examples() -> Result<Vec<PathBuf>> {
371    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
372    let entries = std::fs::read_dir(path).unwrap();
373    let mut result_paths = Vec::new();
374    for entry in entries {
375        let entry = entry?;
376        let path = entry.path();
377        if path.is_dir() {
378            result_paths.push(path);
379        }
380    }
381    Ok(result_paths)
382}
383
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// An instance is built by `init` and shared behind an `Arc`.
pub struct AgentAppState {
    /// Shared language registry.
    pub languages: Arc<LanguageRegistry>,
    /// Shared client; also used in this file to reach telemetry and the HTTP client.
    pub client: Arc<Client>,
    /// User store entity created from `client`.
    pub user_store: Entity<UserStore>,
    /// File system abstraction.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime handle.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder, also handed to `agent::init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
395
/// Initializes the global application state needed to run evals headlessly.
///
/// Sets up the release channel, Tokio integration, the settings store, the HTTP client, and
/// the language, extension, language-model, tool, context-server, prompt, and agent
/// subsystems, then applies the bundled `runner_settings.json` as user settings.
///
/// Returns the shared [`AgentAppState`] holding the handles created along the way.
///
/// # Panics
///
/// Panics if the default settings or the bundled runner settings cannot be applied, or if the
/// HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install a settings store seeded with the default settings before anything reads settings.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings (when it parses as a URI); otherwise fall back to the
    // environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // The guard keeps the Tokio handle entered while the client is constructed
        // (presumably because the reqwest-based client requires a Tokio context — confirm).
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // After creating the client, the app-wide HTTP client is replaced with the client's own.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Node binary options are derived from the project settings inside a `SettingsStore`
    // observer and sent to the Node runtime through a watch channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    prompt_store::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the settings bundled with the eval runner as the user settings.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
500
501pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
502    let model_registry = LanguageModelRegistry::read_global(cx);
503    let model = model_registry
504        .available_models(cx)
505        .find(|model| model.id().0 == model_name);
506
507    let Some(model) = model else {
508        return Err(anyhow!(
509            "No language model named {} was available. Available models: {}",
510            model_name,
511            model_registry
512                .available_models(cx)
513                .map(|model| model.id().0.clone())
514                .collect::<Vec<_>>()
515                .join(", ")
516        ));
517    };
518
519    Ok(model)
520}
521
522pub fn authenticate_model_provider(
523    provider_id: LanguageModelProviderId,
524    cx: &mut App,
525) -> Task<std::result::Result<(), AuthenticateError>> {
526    let model_registry = LanguageModelRegistry::read_global(cx);
527    let model_provider = model_registry.provider(&provider_id).unwrap();
528    model_provider.authenticate(cx)
529}
530
531pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
532    (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
533}
534
535pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
536    futures::executor::block_on(async {
537        get_current_commit_id(repo_path).await.unwrap_or_default()
538    })
539}