eval.rs

  1mod example;
  2mod ids;
  3
  4use client::{Client, ProxySettings, UserStore};
  5pub(crate) use example::*;
  6use telemetry;
  7
  8use ::fs::RealFs;
  9use anyhow::{Result, anyhow};
 10use clap::Parser;
 11use extension::ExtensionHostProxy;
 12use futures::future;
 13use futures::stream::StreamExt;
 14use gpui::http_client::{Uri, read_proxy_from_env};
 15use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
 16use gpui_tokio::Tokio;
 17use language::LanguageRegistry;
 18use language_model::{
 19    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 20};
 21use node_runtime::{NodeBinaryOptions, NodeRuntime};
 22use project::Project;
 23use project::project_settings::ProjectSettings;
 24use prompt_store::PromptBuilder;
 25use release_channel::AppVersion;
 26use reqwest_client::ReqwestClient;
 27use settings::{Settings, SettingsStore};
 28use std::collections::HashSet;
 29use std::path::{Path, PathBuf};
 30use std::sync::Arc;
 31use std::usize;
 32use util::ResultExt as _;
 33
/// Root directory for eval output, resolved relative to the current working directory.
/// Each eval invocation creates a timestamped subdirectory (`%Y-%m-%d_%H-%M-%S`) beneath it.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 35
// Command-line arguments of the `eval` binary, parsed with clap (`--version` is disabled).
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    /// Comma-separated language extensions of the examples to run (default: "rs").
    #[arg(long, value_delimiter = ',')]
    languages: Option<Vec<String>>,
    /// How many times to run the judge on each example run.
    #[arg(long, default_value = "3")]
    judge_repetitions: u32,
    /// Maximum number of examples to run concurrently.
    #[arg(long, default_value = "10")]
    concurrency: usize,
}
 54
 55fn main() {
 56    env_logger::init();
 57
 58    let args = Args::parse();
 59    let all_available_examples = list_all_examples().unwrap();
 60    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 61
 62    let example_paths = all_available_examples
 63        .iter()
 64        .filter_map(|example_path| {
 65            let name = example_path.file_name()?.to_string_lossy();
 66            if args.examples.is_empty()
 67                || args
 68                    .examples
 69                    .iter()
 70                    .any(|name_substring| name.contains(name_substring))
 71            {
 72                Some(example_path.clone())
 73            } else {
 74                None
 75            }
 76        })
 77        .collect::<Vec<_>>();
 78
 79    let http_client = Arc::new(ReqwestClient::new());
 80    let app = Application::headless().with_http_client(http_client.clone());
 81
 82    app.run(move |cx| {
 83        let app_state = init(cx);
 84
 85        let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 86        let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 87        let session_id = uuid::Uuid::new_v4().to_string();
 88
 89        app_state
 90            .client
 91            .telemetry()
 92            .start(system_id, installation_id, session_id, cx);
 93
 94        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 95
 96        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 97            registry.set_default_model(Some(model.clone()), cx);
 98        });
 99
100        let model_provider_id = model.provider_id();
101
102        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
103
104        cx.spawn(async move |cx| {
105            authenticate.await.unwrap();
106
107            std::fs::create_dir_all(REPOS_DIR)?;
108            std::fs::create_dir_all(WORKTREES_DIR)?;
109
110            let run_dir = Path::new(RUNS_DIR).join(format!(
111                "{}",
112                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
113            ));
114            std::fs::create_dir_all(&run_dir)?;
115
116            let mut examples = Vec::new();
117
118            const COLORS: [&str; 12] = [
119                "\x1b[31m", // Red
120                "\x1b[32m", // Green
121                "\x1b[33m", // Yellow
122                "\x1b[34m", // Blue
123                "\x1b[35m", // Magenta
124                "\x1b[36m", // Cyan
125                "\x1b[91m", // Bright Red
126                "\x1b[92m", // Bright Green
127                "\x1b[93m", // Bright Yellow
128                "\x1b[94m", // Bright Blue
129                "\x1b[95m", // Bright Magenta
130                "\x1b[96m", // Bright Cyan
131            ];
132
133            let mut max_name_width = 0;
134            let mut skipped = Vec::new();
135
136            for example_path in &example_paths {
137                let example = Example::load_from_directory(example_path, &run_dir)?;
138
139                if !example
140                    .base
141                    .language_extension
142                    .as_ref()
143                    .map_or(false, |lang| languages.contains(lang))
144                {
145                    skipped.push(example.name);
146                    continue;
147                }
148
149                let name_len = example.name.len();
150                if name_len > max_name_width {
151                    max_name_width = example.name.len();
152                }
153
154                examples.push(example);
155            }
156
157            println!("Skipped examples: {}\n", skipped.join(", "));
158
159            if examples.is_empty() {
160                eprintln!("Filter matched no examples");
161                return cx.update(|cx| cx.quit());
162            }
163
164            let mut repo_urls = HashSet::new();
165            let mut clone_tasks = Vec::new();
166
167            for (i, example) in examples.iter_mut().enumerate() {
168                let color = COLORS[i % COLORS.len()].to_string();
169                example.set_log_prefix_style(&color, max_name_width);
170
171                println!(
172                    "{}Logging to: {}",
173                    example.log_prefix,
174                    example.output_file_path.display()
175                );
176
177                let repo_url = example.base.url.clone();
178                if repo_urls.insert(repo_url.clone()) {
179                    let repo_path = repo_path_for_url(&repo_url);
180
181                    if !repo_path.join(".git").is_dir() {
182                        println!(
183                            "{:<width$}  < {}",
184                            "↓ Cloning",
185                            repo_url,
186                            width = max_name_width
187                        );
188
189                        let git_task = cx.spawn(async move |_cx| {
190                            std::fs::create_dir_all(&repo_path)?;
191                            run_git(&repo_path, &["init"]).await?;
192                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
193                        });
194
195                        clone_tasks.push(git_task);
196                    } else {
197                        println!(
198                            "{:<width$}  < {}",
199                            "✔︎ Already cloned",
200                            repo_url,
201                            width = max_name_width
202                        );
203
204                        let actual_origin =
205                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
206                        if actual_origin != repo_url {
207                            return Err(anyhow!(
208                                "remote origin {} does not match expected origin {}",
209                                actual_origin,
210                                repo_url,
211                            ));
212                        }
213                    }
214                }
215            }
216
217            future::join_all(clone_tasks).await;
218
219            for example in examples.iter_mut() {
220                example.setup().await?;
221            }
222
223            let judge_repetitions = args.judge_repetitions;
224            let concurrency = args.concurrency;
225
226            let tasks = examples
227                .into_iter()
228                .map(|example| {
229                    let app_state = app_state.clone();
230                    let model = model.clone();
231                    cx.spawn(async move |cx| {
232                        let result =
233                            run_example(&example, model, app_state, judge_repetitions, cx).await;
234                        (result, example)
235                    })
236                })
237                .collect::<Vec<_>>();
238
239            let results = futures::stream::iter(tasks)
240                .buffer_unordered(concurrency)
241                .collect::<Vec<(Result<Vec<Result<JudgeOutput>>>, Example)>>()
242                .await;
243
244            println!("\n\n");
245            println!("========================================");
246            println!("              EVAL RESULTS              ");
247            println!("========================================");
248            println!("");
249
250            let mut judge_scores = Vec::new();
251
252            for (result, example) in results {
253                match result {
254                    Err(err) => {
255                        println!("💥 {}{:?}", example.log_prefix, err);
256                    }
257                    Ok(judge_results) => {
258                        for judge_result in judge_results {
259                            match judge_result {
260                                Ok(judge_output) => {
261                                    const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
262
263                                    println!(
264                                        "{} {}{}",
265                                        SCORES[judge_output.score.min(5) as usize],
266                                        example.log_prefix,
267                                        judge_output.score,
268                                    );
269                                    judge_scores.push(judge_output.score);
270                                }
271                                Err(err) => {
272                                    println!("💥 {}{:?}", example.log_prefix, err);
273                                }
274                            }
275                        }
276                    }
277                }
278                println!(
279                    "{}    > {}",
280                    " ".repeat(max_name_width),
281                    example.output_file_path.display()
282                );
283            }
284
285            let score_count = judge_scores.len();
286            let average_score = judge_scores
287                .into_iter()
288                .map(|score| score as f32)
289                .sum::<f32>()
290                / (score_count as f32);
291            println!("\nAverage score: {average_score}");
292
293            std::thread::sleep(std::time::Duration::from_secs(2));
294
295            // Flush telemetry events before exiting
296            app_state.client.telemetry().flush_events();
297
298            cx.update(|cx| cx.quit())
299        })
300        .detach_and_log_err(cx);
301    });
302}
303
304async fn run_example(
305    example: &Example,
306    model: Arc<dyn LanguageModel>,
307    app_state: Arc<AgentAppState>,
308    judge_repetitions: u32,
309    cx: &mut AsyncApp,
310) -> Result<Vec<Result<JudgeOutput>>> {
311    let run_output = cx
312        .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
313        .await?;
314    let diff = example.repository_diff().await?;
315
316    // Run judge for each repetition
317    let mut results = Vec::new();
318    for round in 0..judge_repetitions {
319        let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
320
321        // Log telemetry for this judge result
322        if let Ok(judge_output) = &judge_result {
323            let cohort_id = example
324                .output_file_path
325                .parent()
326                .and_then(|p| p.file_name())
327                .map(|name| name.to_string_lossy().to_string())
328                .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
329
330            telemetry::event!(
331                "Agent Eval Completed",
332                cohort_id = cohort_id,
333                example_name = example.name.clone(),
334                round = round,
335                score = judge_output.score,
336                analysis = judge_output.analysis,
337                tool_use_counts = run_output.tool_use_counts,
338                response_count = run_output.response_count,
339                token_usage = run_output.token_usage,
340                model = model.telemetry_id(),
341                model_provider = model.provider_id().to_string(),
342                repository_url = example.base.url.clone(),
343                repository_revision = example.base.revision.clone(),
344                diagnostics_summary = run_output.diagnostics
345            );
346        }
347
348        results.push(judge_result);
349    }
350
351    app_state.client.telemetry().flush_events();
352
353    Ok(results)
354}
355
356fn list_all_examples() -> Result<Vec<PathBuf>> {
357    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
358    let entries = std::fs::read_dir(path).unwrap();
359    let mut result_paths = Vec::new();
360    for entry in entries {
361        let entry = entry?;
362        let path = entry.path();
363        if path.is_dir() {
364            result_paths.push(path);
365        }
366    }
367    Ok(result_paths)
368}
369
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created via `Client::production`; its telemetry is started and flushed by the
    /// eval runner.
    pub client: Arc<Client>,
    /// User store entity built from `client`.
    pub user_store: Entity<UserStore>,
    /// Real filesystem (`RealFs`) handle, also passed to the model, prompt and agent setup.
    pub fs: Arc<dyn fs::Fs>,
    /// Node.js runtime whose binary options follow the `node` project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded from `fs` during `init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
381
/// Initializes the global application state the eval runner needs and returns the shared
/// handles as an [`AgentAppState`].
///
/// The statements are order-sensitive: later steps read globals installed by earlier ones.
/// For example, `Tokio::handle(cx)` relies on `gpui_tokio::init`, and `ProxySettings` relies on
/// the settings store and `client::init_settings`. The function ends by applying
/// `runner_settings.json` as the user settings layer.
///
/// # Panics
///
/// Panics if the bundled default settings or `runner_settings.json` cannot be applied, or if
/// the HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // Uses a default `SemanticVersion` (presumably the eval runner has no release version of
    // its own).
    release_channel::init(SemanticVersion::default(), cx);
    // Must run before `Tokio::handle(cx)` is used below.
    gpui_tokio::init(cx);

    // Install the settings store, seeded with the bundled defaults, before registering any
    // settings that read from it.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the `proxy` setting; fall back to the environment's proxy when the setting is
    // unset or does not parse as a URI.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Keep the Tokio runtime context entered while the client is built (presumably
        // required by reqwest).
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // From here on, the app-level HTTP client is the one owned by `client`.
    cx.set_http_client(client.http_client().clone());

    // No explicit git binary path is configured for `RealFs`.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the global settings change, recompute the Node.js binary options from
    // `ProjectSettings::node` and publish them to the Node runtime through the watch channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    // The prompt builder is told stdout is not a PTY.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the eval-specific settings as the user settings layer on top of the defaults.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
485
486pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
487    let model_registry = LanguageModelRegistry::read_global(cx);
488    let model = model_registry
489        .available_models(cx)
490        .find(|model| model.id().0 == model_name);
491
492    let Some(model) = model else {
493        return Err(anyhow!(
494            "No language model named {} was available. Available models: {}",
495            model_name,
496            model_registry
497                .available_models(cx)
498                .map(|model| model.id().0.clone())
499                .collect::<Vec<_>>()
500                .join(", ")
501        ));
502    };
503
504    Ok(model)
505}
506
507pub fn authenticate_model_provider(
508    provider_id: LanguageModelProviderId,
509    cx: &mut App,
510) -> Task<std::result::Result<(), AuthenticateError>> {
511    let model_registry = LanguageModelRegistry::read_global(cx);
512    let model_provider = model_registry.provider(&provider_id).unwrap();
513    model_provider.authenticate(cx)
514}