main.rs
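//! Entry point for the `assistant_eval` binary: discovers exercises in the adjacent
//! `zed-ace-framework` checkout, runs each against a set of prompt templates, has a
//! judge model score the results, and saves the results for each exercise.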

mod eval;
mod get_exercise;
mod git_commands;
mod headless_assistant;
mod judge;
mod templates_eval;

use clap::Parser;
use eval::{run_exercise_eval, save_eval_results};
use futures::stream::{self, StreamExt};
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
use git_commands::read_base_sha;
use gpui::Application;
use headless_assistant::{authenticate_model_provider, find_model};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use std::{path::PathBuf, sync::Arc};
use templates_eval::all_templates;

#[derive(Parser, Debug)]
#[command(
    name = "assistant_eval",
    disable_version_flag = true,
    before_help = "Tool eval runner"
)]
struct Args {
    /// Only run exercises whose names contain one of these strings.
    #[arg(long)]
    exercise_names: Vec<String>,
    /// Run all discovered exercises (ignored when `--exercise-names` is given).
    #[arg(long)]
    all: bool,
    /// Comma-separated list of languages to evaluate (default: "internal").
    /// "internal" is data generated from the agent panel.
    #[arg(long, default_value = "internal")]
    languages: String,
    /// Name of the model (default: "claude-3-7-sonnet-latest").
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model_name: String,
    /// Name of the judge model (default: value of `--model-name`).
    #[arg(long)]
    judge_model_name: Option<String>,
    /// Number of exercises to evaluate concurrently (default: 3).
    #[arg(short, long, default_value = "3")]
    concurrency: usize,
    /// Maximum number of exercises to evaluate per language.
    #[arg(long)]
    max_exercises_per_language: Option<usize>,
}

// The order in which templates are executed for each exercise.
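// These must match the `name` field of the templates returned by `all_templates()`
// for the ordering to apply; unlisted templates sort after the listed ones.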
const TEMPLATE_EXECUTION_ORDER: [&str; 3] = [
    "ProjectCreation",
    "CodeModification",
    "ConversationalGuidance",
];

fn main() {
    env_logger::init();
    let args = Args::parse();
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client.clone());

    // Path to the zed-ace-framework repo (expected at ../zed-ace-framework relative
    // to the working directory).
    let framework_path = PathBuf::from("../zed-ace-framework")
        .canonicalize()
        .unwrap();

    // Collect the comma-separated language list into owned Strings (rather than
    // borrowed slices) so it can be moved into the async task below.
    let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();

    println!("Using zed-ace-framework at: {:?}", framework_path);
    println!("Evaluating languages: {:?}", languages);

    app.run(move |cx| {
        let app_state = headless_assistant::init(cx);

        let model = find_model(&args.model_name, cx).unwrap();
        let judge_model = if let Some(model_name) = &args.judge_model_name {
            find_model(model_name, cx).unwrap()
        } else {
            model.clone()
        };

        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.set_default_model(Some(model.clone()), cx);
        });

        let model_provider_id = model.provider_id();
        let judge_model_provider_id = judge_model.provider_id();

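        // Clone everything that needs to move into the async eval task below.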
        let framework_path_clone = framework_path.clone();
        let languages_clone = languages.clone();
        let exercise_names = args.exercise_names.clone();
        let all_flag = args.all;

        cx.spawn(async move |cx| {
            // Authenticate both model providers first
            cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();
            cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();

            // Read base SHA from setup.json
            let base_sha = read_base_sha(&framework_path_clone).await.unwrap();

            // Find all exercises for the specified languages
            let all_exercises = find_exercises(
                &framework_path_clone,
                &languages_clone
                    .iter()
                    .map(|s| s.as_str())
                    .collect::<Vec<_>>(),
                args.max_exercises_per_language,
            )
            .unwrap();
            println!("Found {} exercises total", all_exercises.len());

            // Filter exercises if specific ones were requested
            let exercises_to_run = if !exercise_names.is_empty() {
                // If exercise names are specified, filter by them regardless of the --all flag
                all_exercises
                    .into_iter()
                    .filter(|path| {
                        let name = get_exercise_name(path);
                        exercise_names.iter().any(|filter| name.contains(filter))
                    })
                    .collect()
            } else if all_flag {
                // Only use all_flag if no exercise names are specified
                all_exercises
            } else {
                // Default behavior (no filters)
                all_exercises
            };

            println!("Will run {} exercises", exercises_to_run.len());

            // Get all templates and sort them according to TEMPLATE_EXECUTION_ORDER
            let mut templates = all_templates();
            templates.sort_by_key(|template| {
                TEMPLATE_EXECUTION_ORDER
                    .iter()
                    .position(|&name| name == template.name)
                    .unwrap_or(usize::MAX)
            });

            // Create exercise eval tasks - each exercise is a single task that will run templates sequentially
            let exercise_tasks: Vec<_> = exercises_to_run
                .into_iter()
                .map(|exercise_path| {
                    let exercise_name = get_exercise_name(&exercise_path);
                    let templates_clone = templates.clone();
                    let model_clone = model.clone();
                    let judge_model_clone = judge_model.clone();
                    let app_state_clone = app_state.clone();
                    let base_sha_clone = base_sha.clone();
                    let framework_path_clone = framework_path_clone.clone();
                    let cx_clone = cx.clone();

                    async move {
                        println!("Processing exercise: {}", exercise_name);
                        let mut exercise_results = Vec::new();

                        // Determine the language for this exercise
                        let language = match get_exercise_language(&exercise_path) {
                            Ok(lang) => lang,
                            Err(err) => {
                                println!(
                                    "Error determining language for {}: {}",
                                    exercise_name, err
                                );
                                return exercise_results;
                            }
                        };

                        // Run each template sequentially for this exercise
                        for template in templates_clone {
                            // For "multi" or "internal" language, only run the CodeModification template
                            if (language == "multi" || language == "internal")
                                && template.name != "CodeModification"
                            {
                                println!(
                                    "Skipping {} template for {} language",
                                    template.name, language
                                );
                                continue;
                            }

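                            // Run this template against the exercise and have the
                            // judge model score the result.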
                            match run_exercise_eval(
                                exercise_path.clone(),
                                template.clone(),
                                model_clone.clone(),
                                judge_model_clone.clone(),
                                app_state_clone.clone(),
                                base_sha_clone.clone(),
                                framework_path_clone.clone(),
                                cx_clone.clone(),
                            )
                            .await
                            {
                                Ok(result) => {
                                    println!(
                                        "Completed {} with template {} - score: {}",
                                        exercise_name, template.name, result.score
                                    );
                                    exercise_results.push(result);
                                }
                                Err(err) => {
                                    println!(
                                        "Error running {} with template {}: {}",
                                        exercise_name, template.name, err
                                    );
                                }
                            }
                        }

                        // Save results for this exercise
                        if !exercise_results.is_empty() {
                            if let Err(err) =
                                save_eval_results(&exercise_path, exercise_results.clone()).await
                            {
                                println!("Error saving results for {}: {}", exercise_name, err);
                            } else {
                                println!("Saved results for {}", exercise_name);
                            }
                        }

                        exercise_results
                    }
                })
                .collect();

            println!(
                "Running {} exercises with concurrency: {}",
                exercise_tasks.len(),
                args.concurrency
            );

            // Run exercises concurrently, with each exercise running its templates sequentially
            let all_results = stream::iter(exercise_tasks)
                .buffer_unordered(args.concurrency)
                .flat_map(stream::iter)
                .collect::<Vec<_>>()
                .await;

            println!("Completed {} evaluation runs", all_results.len());
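            // Quit the headless app so that `app.run` returns on the main thread.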
            cx.update(|cx| cx.quit()).unwrap();
        })
        .detach();
    });

    println!("Done running evals");
}