main.rs

  1mod eval;
  2mod headless_assistant;
  3mod judge;
  4
use clap::Parser;
use eval::{Eval, EvalOutput};
use futures::{StreamExt, future, stream};
use gpui::{Application, AsyncApp};
use headless_assistant::{HeadlessAppState, authenticate_model_provider, find_model};
use itertools::Itertools;
use judge::Judge;
use language_model::{LanguageModel, LanguageModelRegistry};
use regex::Regex;
use reqwest_client::ReqwestClient;
use std::{cmp, path::PathBuf, sync::Arc};
 16
 17#[derive(Parser, Debug)]
 18#[command(
 19    name = "assistant_eval",
 20    disable_version_flag = true,
 21    before_help = "Tool eval runner"
 22)]
 23struct Args {
 24    /// Regexes to match the names of evals to run.
 25    eval_name_regexes: Vec<String>,
 26    /// Runs all evals in `evaluation_data`, causes the regex to be ignored.
 27    #[arg(long)]
 28    all: bool,
 29    /// Name of the model (default: "claude-3-7-sonnet-latest")
 30    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 31    model_name: String,
 32    /// Name of the editor model (default: value of `--model_name`).
 33    #[arg(long)]
 34    editor_model_name: Option<String>,
 35    /// Name of the judge model (default: value of `--model_name`).
 36    #[arg(long)]
 37    judge_model_name: Option<String>,
 38    /// Number of evaluations to run concurrently (default: 10)
 39    #[arg(short, long, default_value = "10")]
 40    concurrency: usize,
 41}
 42
 43fn main() {
 44    env_logger::init();
 45    let args = Args::parse();
 46    let http_client = Arc::new(ReqwestClient::new());
 47    let app = Application::headless().with_http_client(http_client.clone());
 48
 49    let crate_dir = PathBuf::from("../zed-agent-bench");
 50    let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
 51
 52    let repos_dir = crate_dir.join("repos");
 53    if !repos_dir.exists() {
 54        std::fs::create_dir_all(&repos_dir).unwrap();
 55    }
 56    let repos_dir = repos_dir.canonicalize().unwrap();
 57
 58    let all_evals = std::fs::read_dir(&evaluation_data_dir)
 59        .unwrap()
 60        .map(|path| path.unwrap().file_name().to_string_lossy().to_string())
 61        .collect::<Vec<_>>();
 62
 63    let evals_to_run = if args.all {
 64        all_evals
 65    } else {
 66        args.eval_name_regexes
 67            .into_iter()
 68            .map(|regex_string| Regex::new(&regex_string).unwrap())
 69            .flat_map(|regex| {
 70                all_evals
 71                    .iter()
 72                    .filter(|eval_name| regex.is_match(eval_name))
 73                    .cloned()
 74                    .collect::<Vec<_>>()
 75            })
 76            .collect::<Vec<_>>()
 77    };
 78
 79    if evals_to_run.is_empty() {
 80        panic!("Names of evals to run must be provided or `--all` specified");
 81    }
 82
 83    println!("Will run the following evals: {evals_to_run:?}");
 84    println!("Running up to {} evals concurrently", args.concurrency);
 85
 86    let editor_model_name = if let Some(model_name) = args.editor_model_name {
 87        model_name
 88    } else {
 89        args.model_name.clone()
 90    };
 91
 92    let judge_model_name = if let Some(model_name) = args.judge_model_name {
 93        model_name
 94    } else {
 95        args.model_name.clone()
 96    };
 97
 98    app.run(move |cx| {
 99        let app_state = headless_assistant::init(cx);
100
101        let model = find_model(&args.model_name, cx).unwrap();
102        let editor_model = find_model(&editor_model_name, cx).unwrap();
103        let judge_model = find_model(&judge_model_name, cx).unwrap();
104
105        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
106            registry.set_active_model(Some(model.clone()), cx);
107            registry.set_editor_model(Some(editor_model.clone()), cx);
108        });
109
110        let model_provider_id = model.provider_id();
111        let editor_model_provider_id = editor_model.provider_id();
112        let judge_model_provider_id = judge_model.provider_id();
113
114        cx.spawn(async move |cx| {
115            // Authenticate all model providers first
116            cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
117                .unwrap()
118                .await
119                .unwrap();
120            cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
121                .unwrap()
122                .await
123                .unwrap();
124            cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
125                .unwrap()
126                .await
127                .unwrap();
128
129            let eval_load_futures = evals_to_run
130                .into_iter()
131                .map(|eval_name| {
132                    let eval_path = evaluation_data_dir.join(&eval_name);
133                    let load_future = Eval::load(eval_name.clone(), eval_path, &repos_dir);
134                    async move {
135                        match load_future.await {
136                            Ok(eval) => Some(eval),
137                            Err(err) => {
138                                // TODO: Persist errors / surface errors at the end.
139                                println!("Error loading {eval_name}: {err}");
140                                None
141                            }
142                        }
143                    }
144                })
145                .collect::<Vec<_>>();
146
147            let loaded_evals = future::join_all(eval_load_futures)
148                .await
149                .into_iter()
150                .flatten()
151                .collect::<Vec<_>>();
152
153            // The evals need to be loaded and grouped by URL before concurrently running, since
154            // evals that use the same remote URL will use the same working directory.
155            let mut evals_grouped_by_url: Vec<Vec<Eval>> = loaded_evals
156                .into_iter()
157                .map(|eval| (eval.eval_setup.url.clone(), eval))
158                .into_group_map()
159                .into_values()
160                .collect::<Vec<_>>();
161
162            // Sort groups in descending order, so that bigger groups start first.
163            evals_grouped_by_url.sort_by_key(|evals| cmp::Reverse(evals.len()));
164
165            let result_futures = evals_grouped_by_url
166                .into_iter()
167                .map(|evals| {
168                    let model = model.clone();
169                    let judge_model = judge_model.clone();
170                    let app_state = app_state.clone();
171                    let cx = cx.clone();
172
173                    async move {
174                        let mut results = Vec::new();
175                        for eval in evals {
176                            let name = eval.name.clone();
177                            println!("Starting eval named {}", name);
178                            let result = run_eval(
179                                eval,
180                                model.clone(),
181                                judge_model.clone(),
182                                app_state.clone(),
183                                cx.clone(),
184                            )
185                            .await;
186                            results.push((name, result));
187                        }
188                        results
189                    }
190                })
191                .collect::<Vec<_>>();
192
193            let results = future::join_all(result_futures)
194                .await
195                .into_iter()
196                .flatten()
197                .collect::<Vec<_>>();
198
199            // Process results in order of completion
200            for (eval_name, result) in results {
201                match result {
202                    Ok((eval_output, judge_output)) => {
203                        println!("Generated diff for {eval_name}:\n");
204                        println!("{}\n", eval_output.diff);
205                        println!("Last message for {eval_name}:\n");
206                        println!("{}\n", eval_output.last_message);
207                        println!("Elapsed time: {:?}", eval_output.elapsed_time);
208                        println!(
209                            "Assistant response count: {}",
210                            eval_output.assistant_response_count
211                        );
212                        println!("Tool use counts: {:?}", eval_output.tool_use_counts);
213                        println!("Judge output for {eval_name}: {judge_output}");
214                    }
215                    Err(err) => {
216                        // TODO: Persist errors / surface errors at the end.
217                        println!("Error running {eval_name}: {err}");
218                    }
219                }
220            }
221
222            cx.update(|cx| cx.quit()).unwrap();
223        })
224        .detach();
225    });
226
227    println!("Done running evals");
228}
229
230async fn run_eval(
231    eval: Eval,
232    model: Arc<dyn LanguageModel>,
233    judge_model: Arc<dyn LanguageModel>,
234    app_state: Arc<HeadlessAppState>,
235    cx: AsyncApp,
236) -> anyhow::Result<(EvalOutput, String)> {
237    let path = eval.path.clone();
238    let judge = Judge::load(&path, judge_model).await?;
239    let eval_output = cx.update(|cx| eval.run(app_state, model, cx))?.await?;
240    let judge_output = cx.update(|cx| judge.run(&eval_output, cx))?.await?;
241    eval_output.save_to_directory(&path, judge_output.to_string())?;
242    Ok((eval_output, judge_output))
243}