eval.rs

  1mod example;
  2
  3use assistant_settings::AssistantSettings;
  4use client::{Client, ProxySettings, UserStore};
  5pub(crate) use example::*;
  6
  7use ::fs::RealFs;
  8use anyhow::{Result, anyhow};
  9use clap::Parser;
 10use extension::ExtensionHostProxy;
 11use futures::future;
 12use gpui::http_client::{Uri, read_proxy_from_env};
 13use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
 14use gpui_tokio::Tokio;
 15use language::LanguageRegistry;
 16use language_model::{
 17    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 18};
 19use node_runtime::{NodeBinaryOptions, NodeRuntime};
 20use project::Project;
 21use project::project_settings::ProjectSettings;
 22use prompt_store::PromptBuilder;
 23use release_channel::AppVersion;
 24use reqwest_client::ReqwestClient;
 25use settings::{Settings, SettingsStore};
 26use std::collections::HashSet;
 27use std::path::{Path, PathBuf};
 28use std::sync::Arc;
 29use util::ResultExt as _;
 30
 31pub const RUNS_DIR: &str = "./crates/eval/runs";
 32
 33#[derive(Parser, Debug)]
 34#[command(name = "eval", disable_version_flag = true)]
 35struct Args {
 36    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
 37    #[arg(value_name = "EXAMPLE_SUBSTRING")]
 38    examples: Vec<String>,
 39    /// Model to use (default: "claude-3-7-sonnet-latest")
 40    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 41    model: String,
 42    /// Languages to run (comma-separated, e.g. "js,ts,py"). If unspecified, only Rust examples are run.
 43    #[arg(long, value_delimiter = ',')]
 44    languages: Option<Vec<String>>,
 45}
 46
 47fn main() {
 48    env_logger::init();
 49
 50    let args = Args::parse();
 51    let all_available_examples = list_all_examples().unwrap();
 52    let languages = args.languages.unwrap_or_else(|| vec!["rs".to_string()]);
 53
 54    let example_paths = all_available_examples
 55        .iter()
 56        .filter_map(|example_path| {
 57            let name = example_path.file_name()?.to_string_lossy();
 58            if args.examples.is_empty()
 59                || args
 60                    .examples
 61                    .iter()
 62                    .any(|name_substring| name.contains(name_substring))
 63            {
 64                Some(example_path.clone())
 65            } else {
 66                None
 67            }
 68        })
 69        .collect::<Vec<_>>();
 70
 71    let http_client = Arc::new(ReqwestClient::new());
 72    let app = Application::headless().with_http_client(http_client.clone());
 73
 74    app.run(move |cx| {
 75        let app_state = init(cx);
 76
 77        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 78
 79        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 80            registry.set_default_model(Some(model.clone()), cx);
 81        });
 82
 83        let model_provider_id = model.provider_id();
 84
 85        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
 86
 87        cx.spawn(async move |cx| {
 88            authenticate.await.unwrap();
 89
 90            std::fs::create_dir_all(REPOS_DIR)?;
 91            std::fs::create_dir_all(WORKTREES_DIR)?;
 92
 93            let run_dir = Path::new(RUNS_DIR).join(format!(
 94                "{}",
 95                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
 96            ));
 97            std::fs::create_dir_all(&run_dir)?;
 98
 99            let mut examples = Vec::new();
100            for example_path in example_paths {
101                let example = Example::load_from_directory(&example_path, &run_dir)?;
102
103                if !example
104                    .base
105                    .language_extension
106                    .as_ref()
107                    .map_or(false, |lang| languages.contains(lang))
108                {
109                    println!("Skipping {}", example.name);
110                    continue;
111                }
112
113                println!("{}> Logging to {:?}", example.name, example.log_file_path);
114
115                examples.push(example);
116            }
117            let mut repo_urls = HashSet::new();
118
119            let mut clone_tasks = Vec::new();
120
121            for example in examples.iter() {
122                let repo_url = example.base.url.clone();
123                if repo_urls.insert(repo_url.clone()) {
124                    let repo_path = repo_path_for_url(&repo_url);
125
126                    if !repo_path.join(".git").is_dir() {
127                        println!("Cloning: {}", repo_url);
128
129                        let git_task = cx.spawn(async move |_cx| {
130                            std::fs::create_dir_all(&repo_path)?;
131                            run_git(&repo_path, &["init"]).await?;
132                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
133                        });
134
135                        clone_tasks.push(git_task);
136                    } else {
137                        println!("Already cloned: {}", repo_url);
138
139                        let actual_origin =
140                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
141                        if actual_origin != repo_url {
142                            return Err(anyhow!(
143                                "remote origin {} does not match expected origin {}",
144                                actual_origin,
145                                repo_url,
146                            ));
147                        }
148                    }
149                }
150            }
151
152            future::join_all(clone_tasks).await;
153
154            for example in examples.iter() {
155                example.setup().await?;
156            }
157
158            let tasks = examples
159                .into_iter()
160                .map(|example| {
161                    let app_state = app_state.clone();
162                    let model = model.clone();
163                    cx.spawn(async move |cx| {
164                        (run_example(&example, model, app_state, cx).await, example)
165                    })
166                })
167                .collect::<Vec<_>>();
168
169            let results: Vec<(Result<JudgeOutput>, Example)> = future::join_all(tasks).await;
170
171            println!("\n\n");
172            println!("========================================");
173            println!("              EVAL RESULTS              ");
174            println!("========================================");
175            println!("");
176
177            let mut judge_scores = Vec::new();
178
179            for (result, example) in results {
180                println!("📜 {:<30}: {:?}", example.name, example.log_file_path);
181                match result {
182                    Err(err) => {
183                        println!("💥 {:<30}: {:?}", example.name, err);
184                    }
185                    Ok(judge_output) => {
186                        const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
187
188                        println!(
189                            "{} {:<30}: {}",
190                            SCORES[judge_output.score.min(5) as usize],
191                            example.name,
192                            judge_output.score,
193                        );
194                        judge_scores.push(judge_output.score);
195                    }
196                }
197            }
198
199            let score_count = judge_scores.len();
200            let average_score = judge_scores
201                .into_iter()
202                .map(|score| score as f32)
203                .sum::<f32>()
204                / (score_count as f32);
205            println!("\nAverage score: {average_score}");
206
207            cx.update(|cx| cx.quit())
208        })
209        .detach_and_log_err(cx);
210    });
211}
212
213async fn run_example(
214    example: &Example,
215    model: Arc<dyn LanguageModel>,
216    app_state: Arc<AgentAppState>,
217    cx: &mut AsyncApp,
218) -> Result<JudgeOutput> {
219    cx.update(|cx| example.run(model.clone(), app_state, cx))?
220        .await?;
221    let diff = example.repository_diff().await?;
222    example.judge(model, diff, cx).await
223}
224
225fn list_all_examples() -> Result<Vec<PathBuf>> {
226    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
227    let entries = std::fs::read_dir(path).unwrap();
228    let mut result_paths = Vec::new();
229    for entry in entries {
230        let entry = entry?;
231        let path = entry.path();
232        if path.is_dir() {
233            result_paths.push(path);
234        }
235    }
236    Ok(result_paths)
237}
238
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by `init` and handed to each example run behind an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` sets its language-server download dir to `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `init` via `Client::production`.
    pub client: Arc<Client>,
    /// User store entity built from `client`.
    pub user_store: Entity<UserStore>,
    /// Filesystem handle; `init` constructs a `RealFs` with no custom git binary path.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime whose binary options `init` derives from the `node` project settings.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder that `init` loads via `PromptBuilder::load` using `fs`.
    pub prompt_builder: Arc<PromptBuilder>,
}
250
/// Initializes the globals a headless agent run needs and returns the shared handles.
///
/// Statement order is significant. The `SettingsStore` global is installed and
/// `client::init_settings` runs before `ProxySettings` is read. The HTTP client is configured
/// before `Client::production` is created. Assistant settings are overridden only after
/// `agent::init`.
///
/// # Panics
///
/// Panics if the default settings cannot be applied or if the proxy-aware HTTP client cannot
/// be constructed.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Use the proxy from settings when it parses as a URI; otherwise fall back to whatever
    // `read_proxy_from_env` finds.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Enter the Tokio runtime registered by `gpui_tokio::init` while building the client.
        // NOTE(review): presumably the reqwest client requires a Tokio context — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // The global HTTP client is replaced right after with the one owned by `client`.
    // NOTE(review): presumably `Client::production` wraps the client set above — confirm.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // Node binary options are recomputed from the `node` project settings each time the
    // `SettingsStore` global is observed, and published on this watch channel (initially
    // `None`), which `NodeRuntime` receives.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, use `npm` in the node binary's directory.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Auto-approve tool actions on top of the current assistant settings.
    // NOTE(review): presumably because eval runs unattended with no one to confirm — verify.
    AssistantSettings::override_global(
        AssistantSettings {
            always_allow_tool_actions: true,
            ..AssistantSettings::get_global(cx).clone()
        },
        cx,
    );

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
357
358pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
359    let model_registry = LanguageModelRegistry::read_global(cx);
360    let model = model_registry
361        .available_models(cx)
362        .find(|model| model.id().0 == model_name);
363
364    let Some(model) = model else {
365        return Err(anyhow!(
366            "No language model named {} was available. Available models: {}",
367            model_name,
368            model_registry
369                .available_models(cx)
370                .map(|model| model.id().0.clone())
371                .collect::<Vec<_>>()
372                .join(", ")
373        ));
374    };
375
376    Ok(model)
377}
378
379pub fn authenticate_model_provider(
380    provider_id: LanguageModelProviderId,
381    cx: &mut App,
382) -> Task<std::result::Result<(), AuthenticateError>> {
383    let model_registry = LanguageModelRegistry::read_global(cx);
384    let model_provider = model_registry.provider(&provider_id).unwrap();
385    model_provider.authenticate(cx)
386}