eval.rs

  1mod example;
  2
  3use assistant_settings::AssistantSettings;
  4use client::{Client, ProxySettings, UserStore};
  5pub(crate) use example::*;
  6
  7use ::fs::RealFs;
  8use anyhow::{Result, anyhow};
  9use clap::Parser;
 10use extension::ExtensionHostProxy;
 11use futures::future;
 12use gpui::http_client::{Uri, read_proxy_from_env};
 13use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task};
 14use gpui_tokio::Tokio;
 15use language::LanguageRegistry;
 16use language_model::{
 17    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 18};
 19use node_runtime::{NodeBinaryOptions, NodeRuntime};
 20use project::Project;
 21use project::project_settings::ProjectSettings;
 22use prompt_store::PromptBuilder;
 23use release_channel::AppVersion;
 24use reqwest_client::ReqwestClient;
 25use settings::{Settings, SettingsStore};
 26use std::collections::HashSet;
 27use std::path::{Path, PathBuf};
 28use std::sync::Arc;
 29use util::ResultExt as _;
 30
/// Root directory for eval output; `main` creates one timestamped subdirectory in it per
/// invocation. The path is relative, so it is resolved against the process's current working
/// directory.
pub const RUNS_DIR: &str = "./crates/eval/runs";
 32
// Command-line arguments of the `eval` binary, parsed through clap's derive API.
//
// NOTE: clap turns `///` doc comments on this struct and its fields into `--help` text, so
// maintainer-facing notes in this block deliberately use plain `//` comments.
#[derive(Parser, Debug)]
#[command(name = "eval", disable_version_flag = true)]
struct Args {
    /// Runs all examples that contain these substrings. If unspecified, all examples are run.
    // Positional, repeatable argument; shown as `EXAMPLE_SUBSTRING` in the usage line.
    #[arg(value_name = "EXAMPLE_SUBSTRING")]
    examples: Vec<String>,
    /// Model to use (default: "claude-3-7-sonnet-latest")
    // `--model <MODEL>`; clap fills in the default value when the flag is omitted.
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model: String,
}
 43
 44fn main() {
 45    env_logger::init();
 46
 47    let args = Args::parse();
 48    let all_available_examples = list_all_examples().unwrap();
 49    let example_paths = all_available_examples
 50        .iter()
 51        .filter_map(|example_path| {
 52            let name = example_path.file_name()?.to_string_lossy();
 53            if args.examples.is_empty()
 54                || args
 55                    .examples
 56                    .iter()
 57                    .any(|name_substring| name.contains(name_substring))
 58            {
 59                Some(example_path.clone())
 60            } else {
 61                None
 62            }
 63        })
 64        .collect::<Vec<_>>();
 65
 66    let http_client = Arc::new(ReqwestClient::new());
 67    let app = Application::headless().with_http_client(http_client.clone());
 68
 69    app.run(move |cx| {
 70        let app_state = init(cx);
 71
 72        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
 73
 74        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 75            registry.set_default_model(Some(model.clone()), cx);
 76        });
 77
 78        let model_provider_id = model.provider_id();
 79
 80        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
 81
 82        cx.spawn(async move |cx| {
 83            authenticate.await.unwrap();
 84
 85            std::fs::create_dir_all(REPOS_DIR)?;
 86            std::fs::create_dir_all(WORKTREES_DIR)?;
 87
 88            let run_dir = Path::new(RUNS_DIR).join(format!(
 89                "{}",
 90                chrono::Local::now().format("%Y-%m-%d_%H-%M-%S")
 91            ));
 92            std::fs::create_dir_all(&run_dir)?;
 93
 94            let mut examples = Vec::new();
 95            for example_path in example_paths {
 96                let example = Example::load_from_directory(&example_path, &run_dir)?;
 97                examples.push((example_path, example));
 98            }
 99            let mut repo_urls = HashSet::new();
100
101            let mut clone_tasks = Vec::new();
102
103            for (_, example) in examples.iter() {
104                let repo_url = example.base.url.clone();
105                if repo_urls.insert(repo_url.clone()) {
106                    let repo_path = repo_path_for_url(&repo_url);
107
108                    if !repo_path.join(".git").is_dir() {
109                        println!("Cloning: {}", repo_url);
110
111                        let git_task = cx.spawn(async move |_cx| {
112                            std::fs::create_dir_all(&repo_path)?;
113                            run_git(&repo_path, &["init"]).await?;
114                            run_git(&repo_path, &["remote", "add", "origin", &repo_url]).await
115                        });
116
117                        clone_tasks.push(git_task);
118                    } else {
119                        println!("Already cloned: {}", repo_url);
120
121                        let actual_origin =
122                            run_git(&repo_path, &["remote", "get-url", "origin"]).await?;
123                        if actual_origin != repo_url {
124                            return Err(anyhow!(
125                                "remote origin {} does not match expected origin {}",
126                                actual_origin,
127                                repo_url,
128                            ));
129                        }
130                    }
131                }
132            }
133
134            future::join_all(clone_tasks).await;
135
136            let tasks = examples
137                .into_iter()
138                .map(|(example_path, example)| {
139                    let app_state = app_state.clone();
140                    let model = model.clone();
141                    cx.spawn(async move |cx| {
142                        (
143                            example_path,
144                            run_example(example, model, app_state, cx).await,
145                        )
146                    })
147                })
148                .collect::<Vec<_>>();
149
150            let results: Vec<(PathBuf, Result<JudgeOutput>)> = future::join_all(tasks).await;
151
152            println!("\n\n");
153            println!("========================================");
154            println!("              EVAL RESULTS              ");
155            println!("========================================");
156            println!("");
157
158            let mut judge_scores = Vec::new();
159
160            for (example_path, result) in results {
161                let example_name = example_path.file_name().unwrap().to_string_lossy();
162                match result {
163                    Err(err) => {
164                        println!("💥 {:<30}: {:?}", example_name, err);
165                    }
166                    Ok(judge_output) => {
167                        const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"];
168
169                        println!(
170                            "{} {:<30}: {}",
171                            SCORES[judge_output.score.min(5) as usize],
172                            example_name,
173                            judge_output.score,
174                        );
175                        judge_scores.push(judge_output.score);
176                    }
177                }
178            }
179
180            let score_count = judge_scores.len();
181            let average_score = judge_scores
182                .into_iter()
183                .map(|score| score as f32)
184                .sum::<f32>()
185                / (score_count as f32);
186            println!("\nAverage score: {average_score}");
187
188            cx.update(|cx| cx.quit())
189        })
190        .detach_and_log_err(cx);
191    });
192}
193
194async fn run_example(
195    mut example: Example,
196    model: Arc<dyn LanguageModel>,
197    app_state: Arc<AgentAppState>,
198    cx: &mut AsyncApp,
199) -> Result<JudgeOutput> {
200    example.setup().await?;
201    cx.update(|cx| example.run(model.clone(), app_state, cx))?
202        .await?;
203    let diff = example.repository_diff().await?;
204    example.judge(model, diff, cx).await
205}
206
207fn list_all_examples() -> Result<Vec<PathBuf>> {
208    let path = std::fs::canonicalize(EXAMPLES_DIR).unwrap();
209    let entries = std::fs::read_dir(path).unwrap();
210    let mut result_paths = Vec::new();
211    for entry in entries {
212        let entry = entry?;
213        let path = entry.path();
214        if path.is_dir() {
215            result_paths.push(path);
216        }
217    }
218    Ok(result_paths)
219}
220
/// Subset of `workspace::AppState` needed by `HeadlessAssistant`, with additional fields.
///
/// Built once by [`init`] and handed out behind an `Arc`.
pub struct AgentAppState {
    /// Language registry; `init` points its language-server download directory at
    /// `paths::languages_dir()`.
    pub languages: Arc<LanguageRegistry>,
    /// Client created by `init` through `Client::production`.
    pub client: Arc<Client>,
    /// User store entity created from `client`.
    pub user_store: Entity<UserStore>,
    /// File system handle; `init` supplies a `RealFs`.
    pub fs: Arc<dyn fs::Fs>,
    /// Node runtime built by `init` from the client's HTTP client and a watch channel that
    /// carries the current `NodeBinaryOptions`.
    pub node_runtime: NodeRuntime,

    // Additional fields not present in `workspace::AppState`.
    /// Prompt builder loaded by `init` via `PromptBuilder::load`.
    pub prompt_builder: Arc<PromptBuilder>,
}
232
/// Initializes the global services used by the eval binary and returns the shared
/// [`AgentAppState`].
///
/// In order, this sets up the release channel, the Tokio integration, the settings store
/// (with the built-in defaults), the HTTP client, the `Client`, the file system, the language
/// registry, the user store, extensions, the Node runtime, the language / language-model /
/// tool / context-server / agent subsystems, and finally overrides the assistant settings so
/// that tool actions are always allowed.
///
/// The statement order is significant: later steps read globals installed by earlier ones.
///
/// # Panics
///
/// Panics if the default settings cannot be applied, if the HTTP client cannot be created, or
/// if a global that one of the called initializers expects is missing.
pub fn init(cx: &mut App) -> Arc<AgentAppState> {
    release_channel::init(SemanticVersion::default(), cx);
    gpui_tokio::init(cx);

    // Install a settings store that only contains the built-in default settings.
    let mut settings_store = SettingsStore::new(cx);
    settings_store
        .set_default_settings(settings::default_settings().as_ref(), cx)
        .unwrap();
    cx.set_global(settings_store);
    client::init_settings(cx);

    // Set User-Agent so we can download language servers from GitHub
    let user_agent = format!(
        "Zed/{} ({}; {})",
        AppVersion::global(cx),
        std::env::consts::OS,
        std::env::consts::ARCH
    );
    // Prefer the proxy from the settings; fall back to the environment when it is unset or
    // does not parse as a `Uri`.
    let proxy_str = ProxySettings::get_global(cx).proxy.to_owned();
    let proxy_url = proxy_str
        .as_ref()
        .and_then(|input| input.parse::<Uri>().ok())
        .or_else(read_proxy_from_env);
    let http = {
        // Keep the Tokio handle entered while the client is constructed.
        // NOTE(review): presumably the reqwest-based client requires a Tokio runtime context
        // at construction time — confirm.
        let _guard = Tokio::handle(cx).enter();

        ReqwestClient::proxy_and_user_agent(proxy_url, &user_agent)
            .expect("could not start HTTP client")
    };
    cx.set_http_client(Arc::new(http));

    Project::init_settings(cx);

    // From here on the app-wide HTTP client is the one exposed by `client`.
    let client = Client::production(cx);
    cx.set_http_client(client.http_client().clone());

    // No explicit git binary is configured for the file system.
    let git_binary_path = None;
    let fs = Arc::new(RealFs::new(
        git_binary_path,
        cx.background_executor().clone(),
    ));

    let mut languages = LanguageRegistry::new(cx.background_executor().clone());
    languages.set_language_server_download_dir(paths::languages_dir().clone());
    let languages = Arc::new(languages);

    let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));

    extension::init(cx);

    // The Node runtime receives its options through a watch channel. The observer below
    // recomputes them from the project settings and sends the new value.
    // NOTE(review): the channel starts at `None`; the first options are only sent once the
    // observer fires — confirm that happens before a Node binary is needed.
    let (tx, rx) = async_watch::channel(None);
    cx.observe_global::<SettingsStore>(move |cx| {
        let settings = &ProjectSettings::get_global(cx).node;
        let options = NodeBinaryOptions {
            allow_path_lookup: !settings.ignore_system_version.unwrap_or_default(),
            allow_binary_download: true,
            use_paths: settings.path.as_ref().map(|node_path| {
                let node_path = PathBuf::from(shellexpand::tilde(node_path).as_ref());
                let npm_path = settings
                    .npm_path
                    .as_ref()
                    .map(|path| PathBuf::from(shellexpand::tilde(&path).as_ref()));
                (
                    node_path.clone(),
                    // Without an explicit npm path, look for `npm` next to the node binary.
                    npm_path.unwrap_or_else(|| {
                        let base_path = PathBuf::new();
                        node_path.parent().unwrap_or(&base_path).join("npm")
                    }),
                )
            }),
        };
        // A failed send is logged rather than treated as fatal.
        tx.send(Some(options)).log_err();
    })
    .detach();
    let node_runtime = NodeRuntime::new(client.http_client().clone(), rx);

    let extension_host_proxy = ExtensionHostProxy::global(cx);

    language::init(cx);
    language_extension::init(extension_host_proxy.clone(), languages.clone());
    language_model::init(client.clone(), cx);
    language_models::init(user_store.clone(), client.clone(), fs.clone(), cx);
    languages::init(languages.clone(), node_runtime.clone(), cx);
    assistant_tools::init(client.http_client().clone(), cx);
    context_server::init(cx);
    let stdout_is_a_pty = false;
    let prompt_builder = PromptBuilder::load(fs.clone(), stdout_is_a_pty, cx);
    agent::init(fs.clone(), client.clone(), prompt_builder.clone(), cx);

    // Force `always_allow_tool_actions` on while keeping every other assistant setting.
    // NOTE(review): presumably so tool calls never wait for confirmation in an unattended
    // run — confirm.
    AssistantSettings::override_global(
        AssistantSettings {
            always_allow_tool_actions: true,
            ..AssistantSettings::get_global(cx).clone()
        },
        cx,
    );

    Arc::new(AgentAppState {
        languages,
        client,
        user_store,
        fs,
        node_runtime,
        prompt_builder,
    })
}
339
340pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
341    let model_registry = LanguageModelRegistry::read_global(cx);
342    let model = model_registry
343        .available_models(cx)
344        .find(|model| model.id().0 == model_name);
345
346    let Some(model) = model else {
347        return Err(anyhow!(
348            "No language model named {} was available. Available models: {}",
349            model_name,
350            model_registry
351                .available_models(cx)
352                .map(|model| model.id().0.clone())
353                .collect::<Vec<_>>()
354                .join(", ")
355        ));
356    };
357
358    Ok(model)
359}
360
361pub fn authenticate_model_provider(
362    provider_id: LanguageModelProviderId,
363    cx: &mut App,
364) -> Task<std::result::Result<(), AuthenticateError>> {
365    let model_registry = LanguageModelRegistry::read_global(cx);
366    let model_provider = model_registry.provider(&provider_id).unwrap();
367    model_provider.authenticate(cx)
368}