eval.rs

  1mod example;
  2
  3use assistant_settings::AssistantSettings;
  4use client::{Client, ProxySettings, UserStore};
  5pub(crate) use example::*;
  6
  7use ::fs::RealFs;
  8use anyhow::{Result, anyhow};
  9use clap::Parser;
 10use extension::ExtensionHostProxy;
 11use futures::future;
 12use gpui::http_client::{Uri, read_proxy_from_env};
 13use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
 14use gpui_tokio::Tokio;
 15use language::LanguageRegistry;
 16use language_model::{
 17    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 18};
 19use node_runtime::{NodeBinaryOptions, NodeRuntime};
 20use project::Project;
 21use project::project_settings::ProjectSettings;
 22use prompt_store::PromptBuilder;
 23use release_channel::AppVersion;
 24use reqwest_client::ReqwestClient;
 25use settings::{Settings, SettingsStore};
 26use std::collections::HashSet;
 27use std::path::{Path, PathBuf};
 28use std::sync::Arc;
 29use util::ResultExt as _;
 30
/// Directory under which `main` creates one timestamped subdirectory per eval run
/// (that subdirectory is passed to `Example::load_from_directory`).
///
/// The path is relative, so it resolves against the current working directory;
/// NOTE(review): presumably the binary is meant to be run from the repository root.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 32
// Command-line arguments for the eval binary.
//
// NOTE: the `///` doc comments inside this struct are rendered by clap as `--help`
// text, so explanatory notes for maintainers are kept in plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    // Positional: `main` keeps an example when its directory name contains any of
    // these substrings (all examples when the list is empty).
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    // NOTE(review): confirm `main` passes this value to `find_model` rather than a
    // hard-coded model name.
    /// Model to use (default: "claude-3-7-sonnet-latest")
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
    // Compared against each example's `language_extension`; `main` substitutes
    // `["rs"]` when this flag is omitted.
    /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
    #[arg(long, value_delimiter = ',')]
    languages: Option<Vec<String>>,
}
 46
 47fn main() {
 48    env_logger::init();
 49
 50    let args = Args::parse();
 51    let all_available_examples = list_all_examples().unwrap();
 52    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 53
 54    let example_paths = all_available_examples
 55        .iter()
 56        .filter_map(|example_path| {
 57            let name = example_path.file_name()?.to_string_lossy();
 58            if args.examples.is_empty()
 59                || args
 60                    .examples
 61                    .iter()
 62                    .any(|name_substring| name.contains(name_substring))
 63            {
 64                Some(example_path.clone())
 65            } else {
 66                None
 67            }
 68        })
 69        .collect::<Vec<_>>();
 70
 71    let http_client = Arc::new(ReqwestClient::new());
 72    let app = Application::headless().with_http_client(http_client.clone());
 73
 74    app.run(move |cx| {
 75        let app_state = init(cx);
 76
 77        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 78
 79        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 80            registry.set_default_model(Some(model.clone()), cx);
 81        });
 82
 83        let model_provider_id = model.provider_id();
 84
 85        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
 86
 87        cx.spawn(async move |cx| {
 88            authenticate.await.unwrap();
 89
 90            std::fs::create_dir_all(REPOS_DIR)?;
 91            std::fs::create_dir_all(WORKTREES_DIR)?;
 92
 93            let run_dir = Path::new(RUNS_DIR).join(format!(
 94                "{}",
 95                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
 96            ));
 97            std::fs::create_dir_all(&run_dir)?;
 98
 99            let mut examples = Vec::new();
100            for example_path in example_paths {
101                let example = Example::load_from_directory(&example_path, &run_dir)?;
102
103                if !example
104                    .base
105                    .language_extension
106                    .as_ref()
107                    .map_or(false, |lang| languages.contains(lang))
108                {
109                    println!("Skipping {}", example.name);
110                    continue;
111                }
112
113                examples.push((example_path, example));
114            }
115            let mut repo_urls = HashSet::new();
116
117            let mut clone_tasks = Vec::new();
118
119            for (_, example) in examples.iter() {
120                let repo_url = example.base.url.clone();
121                if repo_urls.insert(repo_url.clone()) {
122                    let repo_path = repo_path_for_url(&repo_url);
123
124                    if !repo_path.join(".git").is_dir() {
125                        println!("Cloning: {}", repo_url);
126
127                        let git_task = cx.spawn(async move |_cx| {
128                            std::fs::create_dir_all(&repo_path)?;
129                            run_git(&repo_path, &["init"]).await?;
130                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
131                        });
132
133                        clone_tasks.push(git_task);
134                    } else {
135                        println!("Already cloned: {}", repo_url);
136
137                        let actual_origin =
138                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
139                        if actual_origin != repo_url {
140                            return Err(anyhow!(
141                                "remote origin {} does not match expected origin {}",
142                                actual_origin,
143                                repo_url,
144                            ));
145                        }
146                    }
147                }
148            }
149
150            future::join_all(clone_tasks).await;
151
152            for (_, example) in examples.iter() {
153                example.setup().await?;
154            }
155
156            let tasks = examples
157                .into_iter()
158                .map(|(example_path, example)| {
159                    let app_state = app_state.clone();
160                    let model = model.clone();
161                    cx.spawn(async move |cx| {
162                        (
163                            example_path,
164                            run_example(example, model, app_state, cx).await,
165                        )
166                    })
167                })
168                .collect::<Vec<_>>();
169
170            let results: Vec<(PathBuf, Result<JudgeOutput>)> = future::join_all(tasks).await;
171
172            println!("\n\n");
173            println!("========================================");
174            println!("              EVAL RESULTS              ");
175            println!("========================================");
176            println!("");
177
178            let mut judge_scores = Vec::new();
179
180            for (example_path, result) in results {
181                let example_name = example_path.file_name().unwrap().to_string_lossy();
182                match result {
183                    Err(err) => {
184                        println!("💥 {:<30}: {:?}", example_name, err);
185                    }
186                    Ok(judge_output) => {
187                        const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
188
189                        println!(
190                            "{} {:<30}: {}",
191                            SCORES[judge_output.score.min(5) as usize],
192                            example_name,
193                            judge_output.score,
194                        );
195                        judge_scores.push(judge_output.score);
196                    }
197                }
198            }
199
200            let score_count = judge_scores.len();
201            let average_score = judge_scores
202                .into_iter()
203                .map(|score| score as f32)
204                .sum::<f32>()
205                / (score_count as f32);
206            println!("\nAverage score: {average_score}");
207
208            cx.update(|cx| cx.quit())
209        })
210        .detach_and_log_err(cx);
211    });
212}
213
214async fn run_example(
215    mut example: Example,
216    model: Arc<dyn LanguageModel>,
217    app_state: Arc<AgentAppState>,
218    cx: &mut AsyncApp,
219) -> Result<JudgeOutput> {
220    cx.update(|cx| example.run(model.clone(), app_state, cx))?
221        .await?;
222    let diff = example.repository_diff().await?;
223    example.judge(model, diff, cx).await
224}
225
226fn list_all_examples() -> Result<Vec<PathBuf>> {
227    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
228    let entries = std::fs::read_dir(path).unwrap();
229    let mut result_paths = Vec::new();
230    for entry in entries {
231        let entry = entry?;
232        let path = entry.path();
233        if path.is_dir() {
234            result_paths.push(path);
235        }
236    }
237    Ok(result_paths)
238}
239
/// Shared application state handed to every eval example run.
///
/// Mirrors a subset of `workspace::AppState`, plus additional fields. [`init`]
/// builds it once and returns it inside an `Arc`, which `main` clones into each
/// example task.
pub struct AgentAppState {
    /// Language registry created in [`init`].
    pub languages: Arc<LanguageRegistry>,
    /// The production [`Client`] created in [`init`].
    pub client: Arc<Client>,
    /// User store entity created from `client` in [`init`].
    pub user_store: Entity<UserStore>,
    /// Filesystem handle ([`init`] supplies a `RealFs`).
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options follow the project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded in [`init`].
    pub prompt_builder: Arc<PromptBuilder>,
}
251
/// Initializes the global state needed to run evals headlessly and returns it
/// bundled as an [`AgentAppState`].
///
/// The steps, in order:
/// - Initialize the release channel (with a default version) and the Tokio integration.
/// - Install a `SettingsStore` populated with the default settings.
/// - Build a proxy-aware HTTP client with a Zed User-Agent.
/// - Create the production [`Client`], a [`RealFs`], the [`LanguageRegistry`],
///   the [`UserStore`] and a [`NodeRuntime`].
/// - Initialize the language, model, tool, context-server and agent subsystems.
/// - Override [`AssistantSettings`] so that tool actions are always allowed.
///
/// NOTE(review): the order of these calls looks significant, because later steps
/// read globals installed by earlier ones. Preserve it when editing.
///
/// # Panics
///
/// Panics if the default settings fail to load or the HTTP client cannot be
/// constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install the settings store globally so the `get_global` calls below
    // (`ProxySettings`, `ProjectSettings`, `AssistantSettings`) can read it.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Use the proxy from settings when it parses as a URI. Otherwise fall back
    // to a proxy read from the environment.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Build the client inside the Tokio runtime context; the guard leaves the
        // context at the end of this block.
        // NOTE(review): presumably reqwest needs an active runtime here. Confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // Create the production client, then make its HTTP client the app-wide one.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    // No explicit git binary path is passed to `RealFs`.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Whenever the settings store changes, derive Node binary options from
    // `ProjectSettings` and publish them to the Node runtime through a watch
    // channel.
    // NOTE(review): if this observer only fires on changes, the channel starts
    // out as `None`. Confirm that `NodeRuntime` handles that.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            // Only set when a node path is configured. `~` is expanded in both paths.
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        // Send failures are passed to `log_err` instead of being propagated.
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    // Hard-coded: this harness does not treat stdout as a terminal.
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Keep the current assistant settings but always allow tool actions,
    // presumably so the agent can act without interactive confirmation.
    AssistantSettings::override_global(
        AssistantSettings {
            always_allow_tool_actions: true,
            ..AssistantSettings::get_global(cx).clone()
        },
        cx,
    );

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
358
359pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
360    let model_registry = LanguageModelRegistry::read_global(cx);
361    let model = model_registry
362        .available_models(cx)
363        .find(|model| model.id().0 == model_name);
364
365    let Some(model) = model else {
366        return Err(anyhow!(
367            "No language model named {} was available. Available models: {}",
368            model_name,
369            model_registry
370                .available_models(cx)
371                .map(|model| model.id().0.clone())
372                .collect::<Vec<_>>()
373                .join(", ")
374        ));
375    };
376
377    Ok(model)
378}
379
380pub fn authenticate_model_provider(
381    provider_id: LanguageModelProviderId,
382    cx: &mut App,
383) -> Task<std::result::Result<(), AuthenticateError>> {
384    let model_registry = LanguageModelRegistry::read_global(cx);
385    let model_provider = model_registry.provider(&provider_id).unwrap();
386    model_provider.authenticate(cx)
387}