eval.rs

  1mod example;
  2
  3use assistant_settings::AssistantSettings;
  4use client::{Client, ProxySettings, UserStore};
  5pub(crate) use example::*;
  6
  7use ::fs::RealFs;
  8use anyhow::{Result, anyhow};
  9use clap::Parser;
 10use extension::ExtensionHostProxy;
 11use futures::future;
 12use gpui::http_client::{Uri, read_proxy_from_env};
 13use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
 14use gpui_tokio::Tokio;
 15use language::LanguageRegistry;
 16use language_model::{
 17    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 18};
 19use node_runtime::{NodeBinaryOptions, NodeRuntime};
 20use project::Project;
 21use project::project_settings::ProjectSettings;
 22use prompt_store::PromptBuilder;
 23use release_channel::AppVersion;
 24use reqwest_client::ReqwestClient;
 25use settings::{Settings, SettingsStore};
 26use std::collections::HashSet;
 27use std::path::{Path, PathBuf};
 28use std::sync::Arc;
 29use std::usize;
 30use util::ResultExt as _;
 31
/// Directory (relative to the current working directory) under which `main`
/// creates one timestamped subdirectory per eval invocation.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 33
// Command-line arguments of the `eval` binary, parsed with clap's derive API.
//
// NOTE: the `///` comments on the fields below are picked up by clap as the
// `--help` text, so they are user-facing strings; edit them with that in mind.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
    // `None` when the flag is absent; `main` substitutes `["rs"]` in that case.
    #[arg(long, value_delimiter = ',')]
    languages: Option<Vec<String>>,
}
 47
 48fn main() {
 49    env_logger::init();
 50
 51    let args = Args::parse();
 52    let all_available_examples = list_all_examples().unwrap();
 53    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 54
 55    let example_paths = all_available_examples
 56        .iter()
 57        .filter_map(|example_path| {
 58            let name = example_path.file_name()?.to_string_lossy();
 59            if args.examples.is_empty()
 60                || args
 61                    .examples
 62                    .iter()
 63                    .any(|name_substring| name.contains(name_substring))
 64            {
 65                Some(example_path.clone())
 66            } else {
 67                None
 68            }
 69        })
 70        .collect::<Vec<_>>();
 71
 72    let http_client = Arc::new(ReqwestClient::new());
 73    let app = Application::headless().with_http_client(http_client.clone());
 74
 75    app.run(move |cx| {
 76        let app_state = init(cx);
 77
 78        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 79
 80        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 81            registry.set_default_model(Some(model.clone()), cx);
 82        });
 83
 84        let model_provider_id = model.provider_id();
 85
 86        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
 87
 88        cx.spawn(async move |cx| {
 89            authenticate.await.unwrap();
 90
 91            std::fs::create_dir_all(REPOS_DIR)?;
 92            std::fs::create_dir_all(WORKTREES_DIR)?;
 93
 94            let run_dir = Path::new(RUNS_DIR).join(format!(
 95                "{}",
 96                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
 97            ));
 98            std::fs::create_dir_all(&run_dir)?;
 99
100            let mut examples = Vec::new();
101
102            const COLORS: [&str; 12] = [
103                "\x1b[31m", // Red
104                "\x1b[32m", // Green
105                "\x1b[33m", // Yellow
106                "\x1b[34m", // Blue
107                "\x1b[35m", // Magenta
108                "\x1b[36m", // Cyan
109                "\x1b[91m", // Bright Red
110                "\x1b[92m", // Bright Green
111                "\x1b[93m", // Bright Yellow
112                "\x1b[94m", // Bright Blue
113                "\x1b[95m", // Bright Magenta
114                "\x1b[96m", // Bright Cyan
115            ];
116
117            let mut max_name_width = 0;
118            let mut skipped = Vec::new();
119
120            for example_path in &example_paths {
121                let example = Example::load_from_directory(example_path, &run_dir)?;
122
123                if !example
124                    .base
125                    .language_extension
126                    .as_ref()
127                    .map_or(false, |lang| languages.contains(lang))
128                {
129                    skipped.push(example.name);
130                    continue;
131                }
132
133                let name_len = example.name.len();
134                if name_len > max_name_width {
135                    max_name_width = example.name.len();
136                }
137
138                examples.push(example);
139            }
140
141            println!("Skipped examples: {}\n", skipped.join(", "));
142
143            if examples.is_empty() {
144                eprintln!("Filter matched no examples");
145                return cx.update(|cx| cx.quit());
146            }
147
148            let mut repo_urls = HashSet::new();
149            let mut clone_tasks = Vec::new();
150
151            for (i, example) in examples.iter_mut().enumerate() {
152                let color = COLORS[i % COLORS.len()].to_string();
153                example.set_log_prefix_style(&color, max_name_width);
154
155                println!(
156                    "{}Logging to: {}",
157                    example.log_prefix,
158                    example.output_file_path.display()
159                );
160
161                let repo_url = example.base.url.clone();
162                if repo_urls.insert(repo_url.clone()) {
163                    let repo_path = repo_path_for_url(&repo_url);
164
165                    if !repo_path.join(".git").is_dir() {
166                        println!(
167                            "{:<width$}  < {}",
168                            "↓ Cloning",
169                            repo_url,
170                            width = max_name_width
171                        );
172
173                        let git_task = cx.spawn(async move |_cx| {
174                            std::fs::create_dir_all(&repo_path)?;
175                            run_git(&repo_path, &["init"]).await?;
176                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
177                        });
178
179                        clone_tasks.push(git_task);
180                    } else {
181                        println!(
182                            "{:<width$}  < {}",
183                            "✔︎ Already cloned",
184                            repo_url,
185                            width = max_name_width
186                        );
187
188                        let actual_origin =
189                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
190                        if actual_origin != repo_url {
191                            return Err(anyhow!(
192                                "remote origin {} does not match expected origin {}",
193                                actual_origin,
194                                repo_url,
195                            ));
196                        }
197                    }
198                }
199            }
200
201            future::join_all(clone_tasks).await;
202
203            for example in examples.iter() {
204                example.setup().await?;
205            }
206
207            let tasks = examples
208                .into_iter()
209                .map(|example| {
210                    let app_state = app_state.clone();
211                    let model = model.clone();
212                    cx.spawn(async move |cx| {
213                        (run_example(&example, model, app_state, cx).await, example)
214                    })
215                })
216                .collect::<Vec<_>>();
217
218            let results: Vec<(Result<JudgeOutput>, Example)> = future::join_all(tasks).await;
219
220            println!("\n\n");
221            println!("========================================");
222            println!("              EVAL RESULTS              ");
223            println!("========================================");
224            println!("");
225
226            let mut judge_scores = Vec::new();
227
228            for (result, example) in results {
229                match result {
230                    Err(err) => {
231                        println!("💥 {}{:?}", example.log_prefix, err);
232                    }
233                    Ok(judge_output) => {
234                        const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
235
236                        println!(
237                            "{} {}{}",
238                            SCORES[judge_output.score.min(5) as usize],
239                            example.log_prefix,
240                            judge_output.score,
241                        );
242                        judge_scores.push(judge_output.score);
243                    }
244                }
245                println!(
246                    "{}    > {}",
247                    " ".repeat(max_name_width),
248                    example.output_file_path.display()
249                );
250            }
251
252            let score_count = judge_scores.len();
253            let average_score = judge_scores
254                .into_iter()
255                .map(|score| score as f32)
256                .sum::<f32>()
257                / (score_count as f32);
258            println!("\nAverage score: {average_score}");
259
260            cx.update(|cx| cx.quit())
261        })
262        .detach_and_log_err(cx);
263    });
264}
265
266async fn run_example(
267    example: &Example,
268    model: Arc<dyn LanguageModel>,
269    app_state: Arc<AgentAppState>,
270    cx: &mut AsyncApp,
271) -> Result<JudgeOutput> {
272    cx.update(|cx| example.run(model.clone(), app_state, cx))?
273        .await?;
274    let diff = example.repository_diff().await?;
275    example.judge(model, diff, cx).await
276}
277
278fn list_all_examples() -> Result<Vec<PathBuf>> {
279    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
280    let entries = std::fs::read_dir(path).unwrap();
281    let mut result_paths = Vec::new();
282    for entry in entries {
283        let entry = entry?;
284        let path = entry.path();
285        if path.is_dir() {
286            result_paths.push(path);
287        }
288    }
289    Ok(result_paths)
290}
291
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by [`init`] and shared behind an `Arc` with every example run.
pub struct AgentAppState {
    /// Language registry; `init` points it at `paths::languages_dir()` for
    /// language server downloads.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `init` via `Client::production`.
    pub client: Arc<Client>,
    /// User store entity created from `client`.
    pub user_store: Entity<UserStore>,
    /// File system abstraction; `init` installs a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options follow the `node` project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init` and also handed to `agent::init`.
    pub prompt_builder: Arc<PromptBuilder>,
}
303
304pub fn init(cx: &mut App) -> Arc<AgentAppState> {
305    release_channel::init(SemanticVersion::default(), cx);
306    gpui_tokio::init(cx);
307
308    let mut settings_store = SettingsStore::new(cx);
309    settings_store
310        .set_default_settings(settings::default_settings().as_ref(), cx)
311        .unwrap();
312    cx.set_global(settings_store);
313    client::init_settings(cx);
314
315    // Set User-Agent so we can download language servers from GitHub
316    let user_agent = format!(
317        "Zed/{} ({}; {})",
318        AppVersion::global(cx),
319        std::env::consts::OS,
320        std::env::consts::ARCH
321    );
322    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
323    let proxy_url = proxy_str
324        .as_ref()
325        .and_then(|input| input.parse::<Uri>().ok())
326        .or_else(read_proxy_from_env);
327    let http = {
328        let _guard = Tokio::handle(cx).enter();
329
330        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
331            .expect("could not start HTTP client")
332    };
333    cx.set_http_client(Arc::new(http));
334
335    Project::init_settings(cx);
336
337    let client = Client::production(cx);
338    cx.set_http_client(client.http_client().clone());
339
340    let git_binary_path = None;
341    let fs = Arc::new(RealFs::new(
342        git_binary_path,
343        cx.background_executor().clone(),
344    ));
345
346    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
347    languages.set_language_server_download_dir(paths::languages_dir().clone());
348    let languages = Arc::new(languages);
349
350    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
351
352    extension::init(cx);
353
354    let (tx, rx) = async_watch::channel(None);
355    cx.observe_global::<SettingsStore>(move |cx| {
356        let settings = &ProjectSettings::get_global(cx).node;
357        let options = NodeBinaryOptions {
358            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
359            allow_binary_download: true,
360            use_paths: settings.path.as_ref().map(|node_path| {
361                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
362                let npm_path = settings
363                    .npm_path
364                    .as_ref()
365                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
366                (
367                    node_path.clone(),
368                    npm_path.unwrap_or_else(|| {
369                        let base_path = PathBuf::new();
370                        node_path.parent().unwrap_or(&base_path).join("npm")
371                    }),
372                )
373            }),
374        };
375        tx.send(Some(options)).log_err();
376    })
377    .detach();
378    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);
379
380    let extension_host_proxy = ExtensionHostProxy::global(cx);
381
382    language::init(cx);
383    language_extension::init(extension_host_proxy.clone(), languages.clone());
384    language_model::init(client.clone(), cx);
385    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
386    languages::init(languages.clone(), node_runtime.clone(), cx);
387    assistant_tools::init(client.http_client().clone(), cx);
388    context_server::init(cx);
389    let stdout_is_a_pty = false;
390    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
391    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);
392
393    AssistantSettings::override_global(
394        AssistantSettings {
395            always_allow_tool_actions: true,
396            ..AssistantSettings::get_global(cx).clone()
397        },
398        cx,
399    );
400
401    Arc::new(AgentAppState {
402        languages,
403        client,
404        user_store,
405        fs,
406        node_runtime,
407        prompt_builder,
408    })
409}
410
411pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
412    let model_registry = LanguageModelRegistry::read_global(cx);
413    let model = model_registry
414        .available_models(cx)
415        .find(|model| model.id().0 == model_name);
416
417    let Some(model) = model else {
418        return Err(anyhow!(
419            "No language model named {} was available. Available models: {}",
420            model_name,
421            model_registry
422                .available_models(cx)
423                .map(|model| model.id().0.clone())
424                .collect::<Vec<_>>()
425                .join(", ")
426        ));
427    };
428
429    Ok(model)
430}
431
432pub fn authenticate_model_provider(
433    provider_id: LanguageModelProviderId,
434    cx: &mut App,
435) -> Task<std::result::Result<(), AuthenticateError>> {
436    let model_registry = LanguageModelRegistry::read_global(cx);
437    let model_provider = model_registry.provider(&provider_id).unwrap();
438    model_provider.authenticate(cx)
439}