eval.rs

  1mod example;
  2mod ids;
  3
  4use client::{Client, ProxySettings, UserStore};
  5pub(crate) use example::*;
  6use telemetry;
  7
  8use ::fs::RealFs;
  9use anyhow::{Result, anyhow};
 10use clap::Parser;
 11use extension::ExtensionHostProxy;
 12use futures::{StreamExt, future};
 13use gpui::http_client::{Uri, read_proxy_from_env};
 14use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
 15use gpui_tokio::Tokio;
 16use language::LanguageRegistry;
 17use language_model::{
 18    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 19};
 20use node_runtime::{NodeBinaryOptions, NodeRuntime};
 21use project::Project;
 22use project::project_settings::ProjectSettings;
 23use prompt_store::PromptBuilder;
 24use release_channel::AppVersion;
 25use reqwest_client::ReqwestClient;
 26use settings::{Settings, SettingsStore};
 27use std::collections::HashSet;
 28use std::path::{Path, PathBuf};
 29use std::sync::Arc;
 30use std::usize;
 31use util::ResultExt as _;
 32
/// Root directory for eval output, relative to the process working directory.
/// `main` creates one timestamped subdirectory beneath it per invocation.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 34
 35#[derive(Parser, Debug)]
 36#[command(name = "eval", disable_version_flag = true)]
 37struct Args {
 38    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 39    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 40    examples: Vec<String>,
 41    /// Model to use (default: "claude-3-7-sonnet-latest")
 42    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 43    model: String,
 44    #[arg(long, value_delimiter = ',')]
 45    languages: Option<Vec<String>>,
 46    /// How many times to run each example. Note that this is currently not very efficient as N
 47    /// worktrees will be created for the examples.
 48    #[arg(long, default_value = "1")]
 49    repetitions: u32,
 50    /// How many times to run the judge on each example run.
 51    #[arg(long, default_value = "3")]
 52    judge_repetitions: u32,
 53    /// Maximum number of examples to run concurrently.
 54    #[arg(long, default_value = "10")]
 55    concurrency: usize,
 56}
 57
 58fn main() {
 59    env_logger::init();
 60
 61    let args = Args::parse();
 62    let all_available_examples = list_all_examples().unwrap();
 63    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 64
 65    let example_paths = all_available_examples
 66        .iter()
 67        .filter_map(|example_path| {
 68            let name = example_path.file_name()?.to_string_lossy();
 69            if args.examples.is_empty()
 70                || args
 71                    .examples
 72                    .iter()
 73                    .any(|name_substring| name.contains(name_substring))
 74            {
 75                Some(example_path.clone())
 76            } else {
 77                None
 78            }
 79        })
 80        .collect::<Vec<_>>();
 81
 82    let http_client = Arc::new(ReqwestClient::new());
 83    let app = Application::headless().with_http_client(http_client.clone());
 84
 85    app.run(move |cx| {
 86        let app_state = init(cx);
 87
 88        let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 89        let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 90        let session_id = uuid::Uuid::new_v4().to_string();
 91
 92        app_state
 93            .client
 94            .telemetry()
 95            .start(system_id, installation_id, session_id, cx);
 96
 97        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 98
 99        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
100            registry.set_default_model(Some(model.clone()), cx);
101        });
102
103        let model_provider_id = model.provider_id();
104
105        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
106
107        cx.spawn(async move |cx| {
108            authenticate.await.unwrap();
109
110            std::fs::create_dir_all(REPOS_DIR)?;
111            std::fs::create_dir_all(WORKTREES_DIR)?;
112
113            let run_dir = Path::new(RUNS_DIR).join(format!(
114                "{}",
115                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
116            ));
117            std::fs::create_dir_all(&run_dir)?;
118
119            let mut examples = Vec::new();
120
121            const COLORS: [&str; 12] = [
122                "\x1b[31m", // Red
123                "\x1b[32m", // Green
124                "\x1b[33m", // Yellow
125                "\x1b[34m", // Blue
126                "\x1b[35m", // Magenta
127                "\x1b[36m", // Cyan
128                "\x1b[91m", // Bright Red
129                "\x1b[92m", // Bright Green
130                "\x1b[93m", // Bright Yellow
131                "\x1b[94m", // Bright Blue
132                "\x1b[95m", // Bright Magenta
133                "\x1b[96m", // Bright Cyan
134            ];
135
136            let mut max_name_width = 0;
137            let mut skipped = Vec::new();
138
139            for example_path in &example_paths {
140                let example = Example::load_from_directory(example_path, &run_dir)?;
141
142                if !example
143                    .base
144                    .language_extension
145                    .as_ref()
146                    .map_or(false, |lang| languages.contains(lang))
147                {
148                    skipped.push(example.name);
149                    continue;
150                }
151
152                // TODO: This creates a worktree per repetition. Ideally these examples should
153                // either be run sequentially on the same worktree, or reuse worktrees when there
154                // are more examples to run than the concurrency limit.
155                for repetition_number in 0..args.repetitions {
156                    let mut example = example.clone();
157                    example.set_repetition_number(repetition_number);
158
159                    let name_len = example.name.len();
160                    if name_len > max_name_width {
161                        max_name_width = example.name.len();
162                    }
163
164                    examples.push(example);
165                }
166            }
167
168            println!("Skipped examples: {}\n", skipped.join(", "));
169
170            if examples.is_empty() {
171                eprintln!("Filter matched no examples");
172                return cx.update(|cx| cx.quit());
173            }
174
175            let mut repo_urls = HashSet::new();
176            let mut clone_tasks = Vec::new();
177
178            for (i, example) in examples.iter_mut().enumerate() {
179                let color = COLORS[i % COLORS.len()].to_string();
180                example.set_log_prefix_style(&color, max_name_width);
181
182                println!(
183                    "{}Logging to: {}",
184                    example.log_prefix,
185                    example.example_output_directory().display()
186                );
187
188                let repo_url = example.base.url.clone();
189                if repo_urls.insert(repo_url.clone()) {
190                    let repo_path = repo_path_for_url(&repo_url);
191
192                    if !repo_path.join(".git").is_dir() {
193                        println!(
194                            "{:<width$} < {}",
195                            "↓ Cloning",
196                            repo_url,
197                            width = max_name_width
198                        );
199
200                        let git_task = cx.spawn(async move |_cx| {
201                            std::fs::create_dir_all(&repo_path)?;
202                            run_git(&repo_path, &["init"]).await?;
203                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
204                        });
205
206                        clone_tasks.push(git_task);
207                    } else {
208                        println!(
209                            "{:<width$}  < {}",
210                            "✔︎ Already cloned",
211                            repo_url,
212                            width = max_name_width
213                        );
214
215                        let actual_origin =
216                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
217                        if actual_origin != repo_url {
218                            return Err(anyhow!(
219                                "remote origin {} does not match expected origin {}",
220                                actual_origin,
221                                repo_url,
222                            ));
223                        }
224                    }
225                }
226            }
227
228            future::join_all(clone_tasks).await;
229
230            for example in examples.iter_mut() {
231                example.setup().await?;
232            }
233
234            let judge_repetitions = args.judge_repetitions;
235            let concurrency = args.concurrency;
236
237            let tasks = examples.iter().map(|example| {
238                let app_state = app_state.clone();
239                let model = model.clone();
240                let example = example.clone();
241                cx.spawn(async move |cx| {
242                    let result =
243                        run_example(&example, model, app_state, judge_repetitions, cx).await;
244                    (result, example)
245                })
246            });
247
248            let results = futures::stream::iter(tasks)
249                .buffer_unordered(concurrency)
250                .collect::<Vec<_>>()
251                .await;
252
253            println!("\n\n");
254            println!("========================================");
255            println!("              EVAL RESULTS              ");
256            println!("========================================");
257            println!("");
258
259            let mut diff_scores = Vec::new();
260            let mut thread_scores = Vec::new();
261            let mut error_count = 0;
262
263            for (result, example) in results {
264                match result {
265                    Err(err) => {
266                        println!("💥 {}{:?}", example.log_prefix, err);
267                        error_count += 1;
268                    }
269                    Ok(judge_results) => {
270                        for judge_result in judge_results {
271                            match judge_result {
272                                Ok(judge_output) => {
273                                    const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
274                                    let diff_score: u32 = judge_output.diff.score;
275                                    let score_index = (diff_score.min(5)) as usize;
276
277                                    println!(
278                                        "{} {}{} (Diff)",
279                                        SCORES[score_index],
280                                        example.log_prefix,
281                                        judge_output.diff.score,
282                                    );
283                                    diff_scores.push(judge_output.diff.score);
284
285                                    if let Some(thread) = judge_output.thread {
286                                        let process_score: u32 = thread.score;
287                                        let score_index = (process_score.min(5)) as usize;
288                                        println!(
289                                            "{} {}{} (Thread)",
290                                            SCORES[score_index], example.log_prefix, thread.score,
291                                        );
292                                        thread_scores.push(thread.score);
293                                    }
294                                }
295                                Err(err) => {
296                                    println!("💥 {}{:?}", example.log_prefix, err);
297                                }
298                            }
299                        }
300                    }
301                }
302                println!(
303                    "{}    > {}",
304                    " ".repeat(max_name_width),
305                    example.example_output_directory().display()
306                );
307            }
308
309            let diff_score_count = diff_scores.len();
310            let average_diff_score = diff_scores
311                .into_iter()
312                .map(|score| score as f32)
313                .sum::<f32>()
314                / (diff_score_count as f32);
315
316            if error_count > 0 {
317                println!("\n{error_count} examples failed to run!");
318            }
319
320            if diff_score_count > 0 {
321                println!("\nAverage code diff score: {average_diff_score}");
322            }
323
324            let thread_score_count = thread_scores.len();
325
326            // We might have gotten no thread scores if we weren't asked to judge the thread.
327            if thread_score_count > 0 {
328                let average_thread_score = thread_scores
329                    .into_iter()
330                    .map(|score| score as f32)
331                    .sum::<f32>()
332                    / (thread_score_count as f32);
333
334                if diff_score_count > 0 {
335                    println!("\nAverage thread score: {average_thread_score}");
336                }
337            }
338
339            std::thread::sleep(std::time::Duration::from_secs(2));
340
341            app_state.client.telemetry().flush_events();
342
343            cx.update(|cx| cx.quit())
344        })
345        .detach_and_log_err(cx);
346    });
347}
348
349async fn run_example(
350    example: &Example,
351    model: Arc<dyn LanguageModel>,
352    app_state: Arc<AgentAppState>,
353    judge_repetitions: u32,
354    cx: &mut AsyncApp,
355) -> Result<Vec<Result<JudgeOutput>>> {
356    let run_output = cx
357        .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
358        .await?;
359
360    let judge_tasks = (0..judge_repetitions)
361        .map(|round| run_judge_repetition(example.clone(), model.clone(), &run_output, round, cx));
362
363    let results = future::join_all(judge_tasks).await;
364
365    app_state.client.telemetry().flush_events();
366
367    Ok(results)
368}
369
370fn list_all_examples() -> Result<Vec<PathBuf>> {
371    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
372    let entries = std::fs::read_dir(path).unwrap();
373    let mut result_paths = Vec::new();
374    for entry in entries {
375        let entry = entry?;
376        let path = entry.path();
377        if path.is_dir() {
378            result_paths.push(path);
379        }
380    }
381    Ok(result_paths)
382}
383
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by `init` and shared behind an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `init` via `Client::production`; also provides the telemetry handle
    /// and HTTP client used elsewhere in this file.
    pub client: Arc<Client>,
    /// User store entity constructed from `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem abstraction; `init` supplies a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options are re-sent whenever the global settings store changes.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded in `init` and also handed to `agent::init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
395
396pub fn init(cx: &mut App) -> Arc<AgentAppState> {
397    release_channel::init(SemanticVersion::default(), cx);
398    gpui_tokio::init(cx);
399
400    let mut settings_store = SettingsStore::new(cx);
401    settings_store
402        .set_default_settings(settings::default_settings().as_ref(), cx)
403        .unwrap();
404    cx.set_global(settings_store);
405    client::init_settings(cx);
406
407    // Set User-Agent so we can download language servers from GitHub
408    let user_agent = format!(
409        "Zed/{} ({}; {})",
410        AppVersion::global(cx),
411        std::env::consts::OS,
412        std::env::consts::ARCH
413    );
414    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
415    let proxy_url = proxy_str
416        .as_ref()
417        .and_then(|input| input.parse::<Uri>().ok())
418        .or_else(read_proxy_from_env);
419    let http = {
420        let _guard = Tokio::handle(cx).enter();
421
422        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
423            .expect("could not start HTTP client")
424    };
425    cx.set_http_client(Arc::new(http));
426
427    Project::init_settings(cx);
428
429    let client = Client::production(cx);
430    cx.set_http_client(client.http_client().clone());
431
432    let git_binary_path = None;
433    let fs = Arc::new(RealFs::new(
434        git_binary_path,
435        cx.background_executor().clone(),
436    ));
437
438    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
439    languages.set_language_server_download_dir(paths::languages_dir().clone());
440    let languages = Arc::new(languages);
441
442    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
443
444    extension::init(cx);
445
446    let (tx, rx) = async_watch::channel(None);
447    cx.observe_global::<SettingsStore>(move |cx| {
448        let settings = &ProjectSettings::get_global(cx).node;
449        let options = NodeBinaryOptions {
450            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
451            allow_binary_download: true,
452            use_paths: settings.path.as_ref().map(|node_path| {
453                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
454                let npm_path = settings
455                    .npm_path
456                    .as_ref()
457                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
458                (
459                    node_path.clone(),
460                    npm_path.unwrap_or_else(|| {
461                        let base_path = PathBuf::new();
462                        node_path.parent().unwrap_or(&base_path).join("npm")
463                    }),
464                )
465            }),
466        };
467        tx.send(Some(options)).log_err();
468    })
469    .detach();
470    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
471
472    let extension_host_proxy = ExtensionHostProxy::global(cx);
473
474    language::init(cx);
475    language_extension::init(extension_host_proxy.clone(), languages.clone());
476    language_model::init(client.clone(), cx);
477    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
478    languages::init(languages.clone(), node_runtime.clone(), cx);
479    assistant_tools::init(client.http_client().clone(), cx);
480    context_server::init(cx);
481    prompt_store::init(cx);
482    let stdout_is_a_pty = false;
483    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
484    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
485
486    SettingsStore::update_global(cx, |store, cx| {
487        store.set_user_settings(include_str!("../runner_settings.json"), cx)
488    })
489    .unwrap();
490
491    Arc::new(AgentAppState {
492        languages,
493        client,
494        user_store,
495        fs,
496        node_runtime,
497        prompt_builder,
498    })
499}
500
501pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
502    let model_registry = LanguageModelRegistry::read_global(cx);
503    let model = model_registry
504        .available_models(cx)
505        .find(|model| model.id().0 == model_name);
506
507    let Some(model) = model else {
508        return Err(anyhow!(
509            "No language model named {} was available. Available models: {}",
510            model_name,
511            model_registry
512                .available_models(cx)
513                .map(|model| model.id().0.clone())
514                .collect::<Vec<_>>()
515                .join(", ")
516        ));
517    };
518
519    Ok(model)
520}
521
522pub fn authenticate_model_provider(
523    provider_id: LanguageModelProviderId,
524    cx: &mut App,
525) -> Task<std::result::Result<(), AuthenticateError>> {
526    let model_registry = LanguageModelRegistry::read_global(cx);
527    let model_provider = model_registry.provider(&provider_id).unwrap();
528    model_provider.authenticate(cx)
529}
530
531pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
532    (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
533}
534
535pub fn get_current_commit_id_sync(repo_path: &Path) -> String {
536    futures::executor::block_on(async {
537        get_current_commit_id(repo_path).await.unwrap_or_default()
538    })
539}
540
541async fn run_judge_repetition(
542    example: Example,
543    model: Arc<dyn LanguageModel>,
544    run_output: &RunOutput,
545    round: u32,
546    cx: &AsyncApp,
547) -> Result<JudgeOutput> {
548    let judge_result = example.judge(model.clone(), &run_output, round, cx).await;
549
550    if let Ok(judge_output) = &judge_result {
551        let cohort_id = example
552            .run_directory_path
553            .file_name()
554            .map(|name| name.to_string_lossy().to_string())
555            .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
556
557        let path = std::path::Path::new(".");
558        let commit_id = get_current_commit_id(path).await.unwrap_or_default();
559
560        if let Some(thread) = &judge_output.thread {
561            telemetry::event!(
562                "Agent Eval Completed",
563                cohort_id = cohort_id,
564                example_name = example.name.clone(),
565                round = round,
566                diff_score = judge_output.diff.score,
567                diff_analysis = judge_output.diff.analysis,
568                thread_score = thread.score,
569                thread_analysis = thread.analysis,
570                tool_use_counts = run_output.tool_use_counts,
571                response_count = run_output.response_count,
572                token_usage = run_output.token_usage,
573                model = model.telemetry_id(),
574                model_provider = model.provider_id().to_string(),
575                repository_url = example.base.url.clone(),
576                repository_revision = example.base.revision.clone(),
577                diagnostics_before = run_output.diagnostics_before,
578                diagnostics_after = run_output.diagnostics_after,
579                commit_id = commit_id
580            );
581        } else {
582            telemetry::event!(
583                "Agent Eval Completed",
584                cohort_id = cohort_id,
585                example_name = example.name.clone(),
586                round = round,
587                diff_score = judge_output.diff.score,
588                diff_analysis = judge_output.diff.analysis,
589                tool_use_counts = run_output.tool_use_counts,
590                response_count = run_output.response_count,
591                token_usage = run_output.token_usage,
592                model = model.telemetry_id(),
593                model_provider = model.provider_id().to_string(),
594                repository_url = example.base.url.clone(),
595                repository_revision = example.base.revision.clone(),
596                diagnostics_before = run_output.diagnostics_before,
597                diagnostics_after = run_output.diagnostics_after,
598                commit_id = commit_id
599            );
600        }
601    }
602
603    judge_result
604}