main.rs

  1mod eval;
  2mod get_exercise;
  3mod git_commands;
  4mod headless_assistant;
  5mod judge;
  6mod templates_eval;
  7
  8use clap::Parser;
  9use eval::{run_exercise_eval, save_eval_results};
 10use futures::stream::{self, StreamExt};
 11use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
 12use git_commands::read_base_sha;
 13use gpui::Application;
 14use headless_assistant::{authenticate_model_provider, find_model};
 15use language_model::LanguageModelRegistry;
 16use reqwest_client::ReqwestClient;
 17use std::{path::PathBuf, sync::Arc};
 18use templates_eval::all_templates;
 19
 20#[derive(Parser, Debug)]
 21#[command(
 22    name = "assistant_eval",
 23    disable_version_flag = true,
 24    before_help = "Tool eval runner"
 25)]
 26struct Args {
 27    /// Match the names of evals to run.
 28    #[arg(long)]
 29    exercise_names: Vec<String>,
 30    /// Runs all exercises, causes the exercise_names to be ignored.
 31    #[arg(long)]
 32    all: bool,
 33    /// Supported language types to evaluate (default: internal).
 34    /// Internal is data generated from the agent panel
 35    #[arg(long, default_value = "internal")]
 36    languages: String,
 37    /// Name of the model (default: "claude-3-7-sonnet-latest")
 38    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 39    model_name: String,
 40    /// Name of the editor model (default: value of `--model_name`).
 41    #[arg(long)]
 42    editor_model_name: Option<String>,
 43    /// Name of the judge model (default: value of `--model_name`).
 44    #[arg(long)]
 45    judge_model_name: Option<String>,
 46    /// Number of evaluations to run concurrently (default: 3)
 47    #[arg(short, long, default_value = "3")]
 48    concurrency: usize,
 49    /// Maximum number of exercises to evaluate per language
 50    #[arg(long)]
 51    max_exercises_per_language: Option<usize>,
 52}
 53
 /// Names of the eval templates in the order they run for each exercise. The sort in `main`
/// places any template whose name is missing from this list after all listed ones.
const TEMPLATE_EXECUTION_ORDER: [&str; 3] =
    ["ProjectCreation", "CodeModification", "ConversationalGuidance"];
 60
 61fn main() {
 62    env_logger::init();
 63    let args = Args::parse();
 64    let http_client = Arc::new(ReqwestClient::new());
 65    let app = Application::headless().with_http_client(http_client.clone());
 66
 67    // Path to the zed-ace-framework repo
 68    let framework_path = PathBuf::from("../zed-ace-framework")
 69        .canonicalize()
 70        .unwrap();
 71
 72    // Fix the 'languages' lifetime issue by creating owned Strings instead of slices
 73    let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
 74
 75    println!("Using zed-ace-framework at: {:?}", framework_path);
 76    println!("Evaluating languages: {:?}", languages);
 77
 78    app.run(move |cx| {
 79        let app_state = headless_assistant::init(cx);
 80
 81        let model = find_model(&args.model_name, cx).unwrap();
 82        let editor_model = if let Some(model_name) = &args.editor_model_name {
 83            find_model(model_name, cx).unwrap()
 84        } else {
 85            model.clone()
 86        };
 87        let judge_model = if let Some(model_name) = &args.judge_model_name {
 88            find_model(model_name, cx).unwrap()
 89        } else {
 90            model.clone()
 91        };
 92
 93        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 94            registry.set_active_model(Some(model.clone()), cx);
 95            registry.set_editor_model(Some(editor_model.clone()), cx);
 96        });
 97
 98        let model_provider_id = model.provider_id();
 99        let editor_model_provider_id = editor_model.provider_id();
100        let judge_model_provider_id = judge_model.provider_id();
101
102        let framework_path_clone = framework_path.clone();
103        let languages_clone = languages.clone();
104        let exercise_names = args.exercise_names.clone();
105        let all_flag = args.all;
106
107        cx.spawn(async move |cx| {
108            // Authenticate all model providers first
109            cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
110                .unwrap()
111                .await
112                .unwrap();
113            cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
114                .unwrap()
115                .await
116                .unwrap();
117            cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
118                .unwrap()
119                .await
120                .unwrap();
121
122            // Read base SHA from setup.json
123            let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
124
125            // Find all exercises for the specified languages
126            let all_exercises = find_exercises(
127                &framework_path_clone,
128                &languages_clone
129                    .iter()
130                    .map(|s| s.as_str())
131                    .collect::<Vec<_>>(),
132                args.max_exercises_per_language,
133            )
134            .unwrap();
135            println!("Found {} exercises total", all_exercises.len());
136
137            // Filter exercises if specific ones were requested
138            let exercises_to_run = if !exercise_names.is_empty() {
139                // If exercise names are specified, filter by them regardless of --all flag
140                all_exercises
141                    .into_iter()
142                    .filter(|path| {
143                        let name = get_exercise_name(path);
144                        exercise_names.iter().any(|filter| name.contains(filter))
145                    })
146                    .collect()
147            } else if all_flag {
148                // Only use all_flag if no exercise names are specified
149                all_exercises
150            } else {
151                // Default behavior (no filters)
152                all_exercises
153            };
154
155            println!("Will run {} exercises", exercises_to_run.len());
156
157            // Get all templates and sort them according to the execution order
158            let mut templates = all_templates();
159            templates.sort_by_key(|template| {
160                TEMPLATE_EXECUTION_ORDER
161                    .iter()
162                    .position(|&name| name == template.name)
163                    .unwrap_or(usize::MAX)
164            });
165
166            // Create exercise eval tasks - each exercise is a single task that will run templates sequentially
167            let exercise_tasks: Vec<_> = exercises_to_run
168                .into_iter()
169                .map(|exercise_path| {
170                    let exercise_name = get_exercise_name(&exercise_path);
171                    let templates_clone = templates.clone();
172                    let model_clone = model.clone();
173                    let judge_model_clone = judge_model.clone();
174                    let app_state_clone = app_state.clone();
175                    let base_sha_clone = base_sha.clone();
176                    let framework_path_clone = framework_path_clone.clone();
177                    let cx_clone = cx.clone();
178
179                    async move {
180                        println!("Processing exercise: {}", exercise_name);
181                        let mut exercise_results = Vec::new();
182
183                        // Determine the language for this exercise
184                        let language = match get_exercise_language(&exercise_path) {
185                            Ok(lang) => lang,
186                            Err(err) => {
187                                println!(
188                                    "Error determining language for {}: {}",
189                                    exercise_name, err
190                                );
191                                return exercise_results;
192                            }
193                        };
194
195                        // Run each template sequentially for this exercise
196                        for template in templates_clone {
197                            // For "multi" or "internal" language, only run the CodeModification template
198                            if (language == "multi" || language == "internal")
199                                && template.name != "CodeModification"
200                            {
201                                println!(
202                                    "Skipping {} template for {} language",
203                                    template.name, language
204                                );
205                                continue;
206                            }
207
208                            match run_exercise_eval(
209                                exercise_path.clone(),
210                                template.clone(),
211                                model_clone.clone(),
212                                judge_model_clone.clone(),
213                                app_state_clone.clone(),
214                                base_sha_clone.clone(),
215                                framework_path_clone.clone(),
216                                cx_clone.clone(),
217                            )
218                            .await
219                            {
220                                Ok(result) => {
221                                    println!(
222                                        "Completed {} with template {} - score: {}",
223                                        exercise_name, template.name, result.score
224                                    );
225                                    exercise_results.push(result);
226                                }
227                                Err(err) => {
228                                    println!(
229                                        "Error running {} with template {}: {}",
230                                        exercise_name, template.name, err
231                                    );
232                                }
233                            }
234                        }
235
236                        // Save results for this exercise
237                        if !exercise_results.is_empty() {
238                            if let Err(err) =
239                                save_eval_results(&exercise_path, exercise_results.clone()).await
240                            {
241                                println!("Error saving results for {}: {}", exercise_name, err);
242                            } else {
243                                println!("Saved results for {}", exercise_name);
244                            }
245                        }
246
247                        exercise_results
248                    }
249                })
250                .collect();
251
252            println!(
253                "Running {} exercises with concurrency: {}",
254                exercise_tasks.len(),
255                args.concurrency
256            );
257
258            // Run exercises concurrently, with each exercise running its templates sequentially
259            let all_results = stream::iter(exercise_tasks)
260                .buffer_unordered(args.concurrency)
261                .flat_map(stream::iter)
262                .collect::<Vec<_>>()
263                .await;
264
265            println!("Completed {} evaluation runs", all_results.len());
266            cx.update(|cx| cx.quit()).unwrap();
267        })
268        .detach();
269    });
270
271    println!("Done running evals");
272}