eval.rs

  1mod example;
  2
  3use client::{Client, ProxySettings, UserStore};
  4pub(crate) use example::*;
  5
  6use ::fs::RealFs;
  7use anyhow::{Result, anyhow};
  8use clap::Parser;
  9use extension::ExtensionHostProxy;
 10use futures::future;
 11use gpui::http_client::{Uri, read_proxy_from_env};
 12use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
 13use gpui_tokio::Tokio;
 14use language::LanguageRegistry;
 15use language_model::{
 16    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 17};
 18use node_runtime::{NodeBinaryOptions, NodeRuntime};
 19use project::Project;
 20use project::project_settings::ProjectSettings;
 21use prompt_store::PromptBuilder;
 22use release_channel::AppVersion;
 23use reqwest_client::ReqwestClient;
 24use settings::{Settings, SettingsStore};
 25use std::collections::HashSet;
 26use std::path::{Path, PathBuf};
 27use std::sync::Arc;
 28use std::usize;
 29use util::ResultExt as _;
 30
/// Directory, relative to the current working directory, under which every eval run gets its
/// own timestamped (`%Y-%m-%d_%H-%M-%S`) subdirectory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 32
// Command-line arguments for the eval runner.
//
// NOTE: plain `//` comments are used for maintainer notes on purpose. clap's derive turns the
// `///` doc comments below into `--help` text, so those lines are user-facing output.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    // Matched (substring containment) against each example directory's file name.
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    // Compared against each example's `language_extension` (the default is `["rs"]`).
    /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
    #[arg(long, value_delimiter = ',')]
    languages: Option<Vec<String>>,
    // Every judge round for an example sees the same diff from a single example run.
    /// How many times to run the judge on each example run.
    #[arg(long, default_value = "3")]
    judge_repetitions: u32,
}
 49
 50fn main() {
 51    env_logger::init();
 52
 53    let args = Args::parse();
 54    let all_available_examples = list_all_examples().unwrap();
 55    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 56
 57    let example_paths = all_available_examples
 58        .iter()
 59        .filter_map(|example_path| {
 60            let name = example_path.file_name()?.to_string_lossy();
 61            if args.examples.is_empty()
 62                || args
 63                    .examples
 64                    .iter()
 65                    .any(|name_substring| name.contains(name_substring))
 66            {
 67                Some(example_path.clone())
 68            } else {
 69                None
 70            }
 71        })
 72        .collect::<Vec<_>>();
 73
 74    let http_client = Arc::new(ReqwestClient::new());
 75    let app = Application::headless().with_http_client(http_client.clone());
 76
 77    app.run(move |cx| {
 78        let app_state = init(cx);
 79
 80        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 81
 82        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 83            registry.set_default_model(Some(model.clone()), cx);
 84        });
 85
 86        let model_provider_id = model.provider_id();
 87
 88        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
 89
 90        cx.spawn(async move |cx| {
 91            authenticate.await.unwrap();
 92
 93            std::fs::create_dir_all(REPOS_DIR)?;
 94            std::fs::create_dir_all(WORKTREES_DIR)?;
 95
 96            let run_dir = Path::new(RUNS_DIR).join(format!(
 97                "{}",
 98                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
 99            ));
100            std::fs::create_dir_all(&run_dir)?;
101
102            let mut examples = Vec::new();
103
104            const COLORS: [&str; 12] = [
105                "\x1b[31m", // Red
106                "\x1b[32m", // Green
107                "\x1b[33m", // Yellow
108                "\x1b[34m", // Blue
109                "\x1b[35m", // Magenta
110                "\x1b[36m", // Cyan
111                "\x1b[91m", // Bright Red
112                "\x1b[92m", // Bright Green
113                "\x1b[93m", // Bright Yellow
114                "\x1b[94m", // Bright Blue
115                "\x1b[95m", // Bright Magenta
116                "\x1b[96m", // Bright Cyan
117            ];
118
119            let mut max_name_width = 0;
120            let mut skipped = Vec::new();
121
122            for example_path in &example_paths {
123                let example = Example::load_from_directory(example_path, &run_dir)?;
124
125                if !example
126                    .base
127                    .language_extension
128                    .as_ref()
129                    .map_or(false, |lang| languages.contains(lang))
130                {
131                    skipped.push(example.name);
132                    continue;
133                }
134
135                let name_len = example.name.len();
136                if name_len > max_name_width {
137                    max_name_width = example.name.len();
138                }
139
140                examples.push(example);
141            }
142
143            println!("Skipped examples: {}\n", skipped.join(", "));
144
145            if examples.is_empty() {
146                eprintln!("Filter matched no examples");
147                return cx.update(|cx| cx.quit());
148            }
149
150            let mut repo_urls = HashSet::new();
151            let mut clone_tasks = Vec::new();
152
153            for (i, example) in examples.iter_mut().enumerate() {
154                let color = COLORS[i % COLORS.len()].to_string();
155                example.set_log_prefix_style(&color, max_name_width);
156
157                println!(
158                    "{}Logging to: {}",
159                    example.log_prefix,
160                    example.output_file_path.display()
161                );
162
163                let repo_url = example.base.url.clone();
164                if repo_urls.insert(repo_url.clone()) {
165                    let repo_path = repo_path_for_url(&repo_url);
166
167                    if !repo_path.join(".git").is_dir() {
168                        println!(
169                            "{:<width$}  < {}",
170                            "↓ Cloning",
171                            repo_url,
172                            width = max_name_width
173                        );
174
175                        let git_task = cx.spawn(async move |_cx| {
176                            std::fs::create_dir_all(&repo_path)?;
177                            run_git(&repo_path, &["init"]).await?;
178                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
179                        });
180
181                        clone_tasks.push(git_task);
182                    } else {
183                        println!(
184                            "{:<width$}  < {}",
185                            "✔︎ Already cloned",
186                            repo_url,
187                            width = max_name_width
188                        );
189
190                        let actual_origin =
191                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
192                        if actual_origin != repo_url {
193                            return Err(anyhow!(
194                                "remote origin {} does not match expected origin {}",
195                                actual_origin,
196                                repo_url,
197                            ));
198                        }
199                    }
200                }
201            }
202
203            future::join_all(clone_tasks).await;
204
205            for example in examples.iter_mut() {
206                example.setup().await?;
207            }
208
209            let judge_repetitions = args.judge_repetitions;
210            let tasks = examples
211                .into_iter()
212                .map(|example| {
213                    let app_state = app_state.clone();
214                    let model = model.clone();
215                    cx.spawn(async move |cx| {
216                        (
217                            run_example(&example, model, app_state, judge_repetitions, cx).await,
218                            example,
219                        )
220                    })
221                })
222                .collect::<Vec<_>>();
223
224            let results: Vec<(Result<Vec<Result<JudgeOutput>>>, Example)> =
225                future::join_all(tasks).await;
226
227            println!("\n\n");
228            println!("========================================");
229            println!("              EVAL RESULTS              ");
230            println!("========================================");
231            println!("");
232
233            let mut judge_scores = Vec::new();
234
235            for (result, example) in results {
236                match result {
237                    Err(err) => {
238                        println!("💥 {}{:?}", example.log_prefix, err);
239                    }
240                    Ok(judge_results) => {
241                        for judge_result in judge_results {
242                            match judge_result {
243                                Ok(judge_output) => {
244                                    const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
245
246                                    println!(
247                                        "{} {}{}",
248                                        SCORES[judge_output.score.min(5) as usize],
249                                        example.log_prefix,
250                                        judge_output.score,
251                                    );
252                                    judge_scores.push(judge_output.score);
253                                }
254                                Err(err) => {
255                                    println!("💥 {}{:?}", example.log_prefix, err);
256                                }
257                            }
258                        }
259                    }
260                }
261                println!(
262                    "{}    > {}",
263                    " ".repeat(max_name_width),
264                    example.output_file_path.display()
265                );
266            }
267
268            let score_count = judge_scores.len();
269            let average_score = judge_scores
270                .into_iter()
271                .map(|score| score as f32)
272                .sum::<f32>()
273                / (score_count as f32);
274            println!("\nAverage score: {average_score}");
275
276            cx.update(|cx| cx.quit())
277        })
278        .detach_and_log_err(cx);
279    });
280}
281
282async fn run_example(
283    example: &Example,
284    model: Arc<dyn LanguageModel>,
285    app_state: Arc<AgentAppState>,
286    judge_repetitions: u32,
287    cx: &mut AsyncApp,
288) -> Result<Vec<Result<JudgeOutput>>> {
289    cx.update(|cx| example.run(model.clone(), app_state, cx))?
290        .await?;
291    let diff = example.repository_diff().await?;
292
293    let judge_tasks = (0..judge_repetitions)
294        .map(|round| example.judge(model.clone(), diff.clone(), round, cx))
295        .collect::<Vec<_>>();
296
297    Ok(future::join_all(judge_tasks).await)
298}
299
300fn list_all_examples() -> Result<Vec<PathBuf>> {
301    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
302    let entries = std::fs::read_dir(path).unwrap();
303    let mut result_paths = Vec::new();
304    for entry in entries {
305        let entry = entry?;
306        let path = entry.path();
307        if path.is_dir() {
308            result_paths.push(path);
309        }
310    }
311    Ok(result_paths)
312}
313
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by [`init`] and shared behind an `Arc` with every example run.
///
/// NOTE(review): `HeadlessAssistant` is not referenced anywhere in this file; confirm the name
/// in the line above is still current.
pub struct AgentAppState {
    /// Language registry. [`init`] points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Production client created in [`init`].
    pub client: Arc<Client>,
    /// User store entity built from `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem handle. [`init`] supplies a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node.js runtime whose binary options are fed from the project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded from `fs`.
    pub prompt_builder: Arc<PromptBuilder>,
}
325
/// Initializes the headless app for running agent examples and returns the shared state.
///
/// In order, this function:
///
/// - sets up the release channel and Tokio integration;
/// - installs the default settings;
/// - builds an HTTP client that honors the configured proxy and falls back to the environment;
/// - creates the production [`Client`], the filesystem, the language registry and the user store;
/// - wires Node.js binary options to the project settings;
/// - initializes the language, model, tool, context-server and agent subsystems;
/// - applies `runner_settings.json` as the user settings.
///
/// # Panics
///
/// Panics if the settings store rejects the default settings or `runner_settings.json`, or if
/// the HTTP client cannot be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from settings (if it parses as a URI); otherwise fall back to the
    // proxy configured in the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the client with the Tokio runtime entered; the guard is dropped at the end
        // of this block.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // The app-wide HTTP client is replaced by the one owned by the production client.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the global settings change, recompute the Node.js binary options and publish
    // them on a watch channel consumed by `NodeRuntime`. The channel starts out as `None`.
    // NOTE(review): confirm the observer fires for the initial settings (e.g. via the
    // `set_user_settings` call at the end of this function) before `NodeRuntime` needs them.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, assume `npm` sits next to the node binary
                    // (or is a bare relative `npm` if the node path has no parent).
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    // This runner always tells the prompt builder that stdout is not a PTY.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Runner-specific settings are applied last, as user settings layered over the defaults.
    SettingsStore::update_global(cx, |store, cx| {
        store.set_user_settings(include_str!("../runner_settings.json"), cx)
    })
    .unwrap();

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
429
430pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
431    let model_registry = LanguageModelRegistry::read_global(cx);
432    let model = model_registry
433        .available_models(cx)
434        .find(|model| model.id().0 == model_name);
435
436    let Some(model) = model else {
437        return Err(anyhow!(
438            "No language model named {} was available. Available models: {}",
439            model_name,
440            model_registry
441                .available_models(cx)
442                .map(|model| model.id().0.clone())
443                .collect::<Vec<_>>()
444                .join(", ")
445        ));
446    };
447
448    Ok(model)
449}
450
451pub fn authenticate_model_provider(
452    provider_id: LanguageModelProviderId,
453    cx: &mut App,
454) -> Task<std::result::Result<(), AuthenticateError>> {
455    let model_registry = LanguageModelRegistry::read_global(cx);
456    let model_provider = model_registry.provider(&provider_id).unwrap();
457    model_provider.authenticate(cx)
458}