eval.rs

  1mod example;
  2mod ids;
  3
  4use client::{Client, ProxySettings, UserStore};
  5pub(crate) use example::*;
  6use telemetry;
  7
  8use ::fs::RealFs;
  9use anyhow::{Result, anyhow};
 10use clap::Parser;
 11use extension::ExtensionHostProxy;
 12use futures::future;
 13use gpui::http_client::{Uri, read_proxy_from_env};
 14use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
 15use gpui_tokio::Tokio;
 16use language::LanguageRegistry;
 17use language_model::{
 18    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 19};
 20use node_runtime::{NodeBinaryOptions, NodeRuntime};
 21use project::Project;
 22use project::project_settings::ProjectSettings;
 23use prompt_store::PromptBuilder;
 24use release_channel::AppVersion;
 25use reqwest_client::ReqwestClient;
 26use settings::{Settings, SettingsStore};
 27use std::collections::HashSet;
 28use std::path::{Path, PathBuf};
 29use std::sync::Arc;
 30use std::usize;
 31use util::ResultExt as _;
 32
/// Directory (relative to the process working directory) under which each eval
/// invocation creates its own timestamped run directory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 34
 35#[derive(Parser, Debug)]
 36#[command(name = "eval", disable_version_flag = true)]
 37struct Args {
 38    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 39    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 40    examples: Vec<String>,
 41    /// Model to use (default: "claude-3-7-sonnet-latest")
 42    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 43    model: String,
 44    #[arg(long, value_delimiter = ',')]
 45    languages: Option<Vec<String>>,
 46    /// How many times to run the judge on each example run.
 47    #[arg(long, default_value = "3")]
 48    judge_repetitions: u32,
 49}
 50
 51fn main() {
 52    env_logger::init();
 53
 54    let args = Args::parse();
 55    let all_available_examples = list_all_examples().unwrap();
 56    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 57
 58    let example_paths = all_available_examples
 59        .iter()
 60        .filter_map(|example_path| {
 61            let name = example_path.file_name()?.to_string_lossy();
 62            if args.examples.is_empty()
 63                || args
 64                    .examples
 65                    .iter()
 66                    .any(|name_substring| name.contains(name_substring))
 67            {
 68                Some(example_path.clone())
 69            } else {
 70                None
 71            }
 72        })
 73        .collect::<Vec<_>>();
 74
 75    let http_client = Arc::new(ReqwestClient::new());
 76    let app = Application::headless().with_http_client(http_client.clone());
 77
 78    app.run(move |cx| {
 79        let app_state = init(cx);
 80
 81        let system_id = ids::get_or_create_id(&ids::eval_system_id_path()).ok();
 82        let installation_id = ids::get_or_create_id(&ids::eval_installation_id_path()).ok();
 83        let session_id = uuid::Uuid::new_v4().to_string();
 84
 85        app_state
 86            .client
 87            .telemetry()
 88            .start(system_id, installation_id, session_id, cx);
 89
 90        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 91
 92        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 93            registry.set_default_model(Some(model.clone()), cx);
 94        });
 95
 96        let model_provider_id = model.provider_id();
 97
 98        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
 99
100        cx.spawn(async move |cx| {
101            authenticate.await.unwrap();
102
103            std::fs::create_dir_all(REPOS_DIR)?;
104            std::fs::create_dir_all(WORKTREES_DIR)?;
105
106            let run_dir = Path::new(RUNS_DIR).join(format!(
107                "{}",
108                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
109            ));
110            std::fs::create_dir_all(&run_dir)?;
111
112            let mut examples = Vec::new();
113
114            const COLORS: [&str; 12] = [
115                "\x1b[31m", // Red
116                "\x1b[32m", // Green
117                "\x1b[33m", // Yellow
118                "\x1b[34m", // Blue
119                "\x1b[35m", // Magenta
120                "\x1b[36m", // Cyan
121                "\x1b[91m", // Bright Red
122                "\x1b[92m", // Bright Green
123                "\x1b[93m", // Bright Yellow
124                "\x1b[94m", // Bright Blue
125                "\x1b[95m", // Bright Magenta
126                "\x1b[96m", // Bright Cyan
127            ];
128
129            let mut max_name_width = 0;
130            let mut skipped = Vec::new();
131
132            for example_path in &example_paths {
133                let example = Example::load_from_directory(example_path, &run_dir)?;
134
135                if !example
136                    .base
137                    .language_extension
138                    .as_ref()
139                    .map_or(false, |lang| languages.contains(lang))
140                {
141                    skipped.push(example.name);
142                    continue;
143                }
144
145                let name_len = example.name.len();
146                if name_len > max_name_width {
147                    max_name_width = example.name.len();
148                }
149
150                examples.push(example);
151            }
152
153            println!("Skipped examples: {}\n", skipped.join(", "));
154
155            if examples.is_empty() {
156                eprintln!("Filter matched no examples");
157                return cx.update(|cx| cx.quit());
158            }
159
160            let mut repo_urls = HashSet::new();
161            let mut clone_tasks = Vec::new();
162
163            for (i, example) in examples.iter_mut().enumerate() {
164                let color = COLORS[i % COLORS.len()].to_string();
165                example.set_log_prefix_style(&color, max_name_width);
166
167                println!(
168                    "{}Logging to: {}",
169                    example.log_prefix,
170                    example.output_file_path.display()
171                );
172
173                let repo_url = example.base.url.clone();
174                if repo_urls.insert(repo_url.clone()) {
175                    let repo_path = repo_path_for_url(&repo_url);
176
177                    if !repo_path.join(".git").is_dir() {
178                        println!(
179                            "{:<width$}  < {}",
180                            "↓ Cloning",
181                            repo_url,
182                            width = max_name_width
183                        );
184
185                        let git_task = cx.spawn(async move |_cx| {
186                            std::fs::create_dir_all(&repo_path)?;
187                            run_git(&repo_path, &["init"]).await?;
188                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
189                        });
190
191                        clone_tasks.push(git_task);
192                    } else {
193                        println!(
194                            "{:<width$}  < {}",
195                            "✔︎ Already cloned",
196                            repo_url,
197                            width = max_name_width
198                        );
199
200                        let actual_origin =
201                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
202                        if actual_origin != repo_url {
203                            return Err(anyhow!(
204                                "remote origin {} does not match expected origin {}",
205                                actual_origin,
206                                repo_url,
207                            ));
208                        }
209                    }
210                }
211            }
212
213            future::join_all(clone_tasks).await;
214
215            for example in examples.iter_mut() {
216                example.setup().await?;
217            }
218
219            let judge_repetitions = args.judge_repetitions;
220            let tasks = examples
221                .into_iter()
222                .map(|example| {
223                    let app_state = app_state.clone();
224                    let model = model.clone();
225                    cx.spawn(async move |cx| {
226                        (
227                            run_example(&example, model, app_state, judge_repetitions, cx).await,
228                            example,
229                        )
230                    })
231                })
232                .collect::<Vec<_>>();
233
234            let results: Vec<(Result<Vec<Result<JudgeOutput>>>, Example)> =
235                future::join_all(tasks).await;
236
237            println!("\n\n");
238            println!("========================================");
239            println!("              EVAL RESULTS              ");
240            println!("========================================");
241            println!("");
242
243            let mut judge_scores = Vec::new();
244
245            for (result, example) in results {
246                match result {
247                    Err(err) => {
248                        println!("💥 {}{:?}", example.log_prefix, err);
249                    }
250                    Ok(judge_results) => {
251                        for judge_result in judge_results {
252                            match judge_result {
253                                Ok(judge_output) => {
254                                    const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
255
256                                    println!(
257                                        "{} {}{}",
258                                        SCORES[judge_output.score.min(5) as usize],
259                                        example.log_prefix,
260                                        judge_output.score,
261                                    );
262                                    judge_scores.push(judge_output.score);
263                                }
264                                Err(err) => {
265                                    println!("💥 {}{:?}", example.log_prefix, err);
266                                }
267                            }
268                        }
269                    }
270                }
271                println!(
272                    "{}    > {}",
273                    " ".repeat(max_name_width),
274                    example.output_file_path.display()
275                );
276            }
277
278            let score_count = judge_scores.len();
279            let average_score = judge_scores
280                .into_iter()
281                .map(|score| score as f32)
282                .sum::<f32>()
283                / (score_count as f32);
284            println!("\nAverage score: {average_score}");
285
286            std::thread::sleep(std::time::Duration::from_secs(2));
287
288            // Flush telemetry events before exiting
289            app_state.client.telemetry().flush_events();
290
291            cx.update(|cx| cx.quit())
292        })
293        .detach_and_log_err(cx);
294    });
295}
296
297async fn run_example(
298    example: &Example,
299    model: Arc<dyn LanguageModel>,
300    app_state: Arc<AgentAppState>,
301    judge_repetitions: u32,
302    cx: &mut AsyncApp,
303) -> Result<Vec<Result<JudgeOutput>>> {
304    let run_output = cx
305        .update(|cx| example.run(model.clone(), app_state.clone(), cx))?
306        .await?;
307    let diff = example.repository_diff().await?;
308
309    // Run judge for each repetition
310    let mut results = Vec::new();
311    for round in 0..judge_repetitions {
312        let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await;
313
314        // Log telemetry for this judge result
315        if let Ok(judge_output) = &judge_result {
316            let cohort_id = example
317                .output_file_path
318                .parent()
319                .and_then(|p| p.file_name())
320                .map(|name| name.to_string_lossy().to_string())
321                .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string());
322
323            telemetry::event!(
324                "Agent Eval Completed",
325                cohort_id = cohort_id,
326                example_name = example.name.clone(),
327                round = round,
328                score = judge_output.score,
329                analysis = judge_output.analysis,
330                tool_use_counts = run_output.tool_use_counts,
331                response_count = run_output.response_count,
332                token_usage = run_output.token_usage,
333                model = model.telemetry_id(),
334                model_provider = model.provider_id().to_string(),
335                repository_url = example.base.url.clone(),
336                repository_revision = example.base.revision.clone(),
337                diagnostics_summary = run_output.diagnostics
338            );
339        }
340
341        results.push(judge_result);
342    }
343
344    app_state.client.telemetry().flush_events();
345
346    Ok(results)
347}
348
349fn list_all_examples() -> Result<Vec<PathBuf>> {
350    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
351    let entries = std::fs::read_dir(path).unwrap();
352    let mut result_paths = Vec::new();
353    for entry in entries {
354        let entry = entry?;
355        let path = entry.path();
356        if path.is_dir() {
357            result_paths.push(path);
358        }
359    }
360    Ok(result_paths)
361}
362
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by [`init`] and shared behind an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created with `Client::production`; also the source of the telemetry
    /// handle and of the HTTP client used elsewhere in this file.
    pub client: Arc<Client>,
    /// User store entity created from `client`.
    pub user_store: Entity<UserStore>,
    /// File system handle; `init` stores a `RealFs` here.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime, configured from the `node` section of the project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init` and also handed to `agent::init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
374
/// Initialises the global application state needed to run the agent headlessly and
/// returns the pieces callers need as an [`AgentAppState`].
///
/// This registers settings, the HTTP client, the language, language-model, tool,
/// context-server and agent subsystems on `cx`, and finally applies the user settings
/// embedded from `../runner_settings.json`. The statements are order-dependent:
/// several of them read globals that earlier ones install.
///
/// # Panics
///
/// Panics if the default settings or the embedded runner settings cannot be applied,
/// or if the HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    // The eval runner has no real release version, so a default one is registered.
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // The settings store must be installed as a global before any `*::init_settings`
    // call or settings lookup below.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from the settings; if it is unset or does not parse as a URI,
    // fall back to the proxy configured through the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // NOTE(review): presumably the reqwest client has to be built inside the Tokio
        // runtime context — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    let client = Client::production(cx);
    // From here on the app uses the HTTP client exposed by `client`.
    // NOTE(review): presumably this wraps the proxy-aware client installed above —
    // confirm.
    cx.set_http_client(client.http_client().clone());

    // No explicit git binary is configured.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the settings store changes, recompute the node binary options from the
    // `node` project settings and publish them to the node runtime over this channel.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            // Explicit (node, npm) paths with `~` expanded. When only the node path is
            // configured, npm is taken to be the `npm` entry next to the node binary.
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        // A failed send is logged rather than treated as fatal.
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Apply the runner's user settings, embedded into the binary at compile time.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
478
479pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
480    let model_registry = LanguageModelRegistry::read_global(cx);
481    let model = model_registry
482        .available_models(cx)
483        .find(|model| model.id().0 == model_name);
484
485    let Some(model) = model else {
486        return Err(anyhow!(
487            "No language model named {} was available. Available models: {}",
488            model_name,
489            model_registry
490                .available_models(cx)
491                .map(|model| model.id().0.clone())
492                .collect::<Vec<_>>()
493                .join(", ")
494        ));
495    };
496
497    Ok(model)
498}
499
500pub fn authenticate_model_provider(
501    provider_id: LanguageModelProviderId,
502    cx: &mut App,
503) -> Task<std::result::Result<(), AuthenticateError>> {
504    let model_registry = LanguageModelRegistry::read_global(cx);
505    let model_provider = model_registry.provider(&provider_id).unwrap();
506    model_provider.authenticate(cx)
507}