main.rs

  1mod eval;
  2mod get_exercise;
  3mod git_commands;
  4mod headless_assistant;
  5
  6use clap::Parser;
  7use eval::{run_exercise_eval, save_eval_results};
  8use futures::stream::{self, StreamExt};
  9use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
 10use git_commands::read_base_sha;
 11use gpui::Application;
 12use headless_assistant::{authenticate_model_provider, find_model};
 13use language_model::LanguageModelRegistry;
 14use reqwest_client::ReqwestClient;
 15use std::{path::PathBuf, sync::Arc};
 16
 17#[derive(Parser, Debug)]
 18#[command(
 19    name = "agent_eval",
 20    disable_version_flag = true,
 21    before_help = "Tool eval runner"
 22)]
 23struct Args {
 24    /// Match the names of evals to run.
 25    #[arg(long)]
 26    exercise_names: Vec<String>,
 27    /// Runs all exercises, causes the exercise_names to be ignored.
 28    #[arg(long)]
 29    all: bool,
 30    /// Supported language types to evaluate (default: internal).
 31    /// Internal is data generated from the agent panel
 32    #[arg(long, default_value = "internal")]
 33    languages: String,
 34    /// Name of the model (default: "claude-3-7-sonnet-latest")
 35    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
 36    model_name: String,
 37    /// Name of the editor model (default: value of `--model_name`).
 38    #[arg(long)]
 39    editor_model_name: Option<String>,
 40    /// Number of evaluations to run concurrently (default: 3)
 41    #[arg(short, long, default_value = "5")]
 42    concurrency: usize,
 43    /// Maximum number of exercises to evaluate per language
 44    #[arg(long)]
 45    max_exercises_per_language: Option<usize>,
 46}
 47
 48fn main() {
 49    env_logger::init();
 50    let args = Args::parse();
 51    let http_client = Arc::new(ReqwestClient::new());
 52    let app = Application::headless().with_http_client(http_client.clone());
 53
 54    // Path to the zed-ace-framework repo
 55    let framework_path = PathBuf::from("../zed-ace-framework")
 56        .canonicalize()
 57        .unwrap();
 58
 59    // Fix the 'languages' lifetime issue by creating owned Strings instead of slices
 60    let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
 61
 62    println!("Using zed-ace-framework at: {:?}", framework_path);
 63    println!("Evaluating languages: {:?}", languages);
 64
 65    app.run(move |cx| {
 66        let app_state = headless_assistant::init(cx);
 67
 68        let model = find_model(&args.model_name, cx).unwrap();
 69        let editor_model = if let Some(model_name) = &args.editor_model_name {
 70            find_model(model_name, cx).unwrap()
 71        } else {
 72            model.clone()
 73        };
 74
 75        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
 76            registry.set_default_model(Some(model.clone()), cx);
 77        });
 78
 79        let model_provider_id = model.provider_id();
 80        let editor_model_provider_id = editor_model.provider_id();
 81
 82        let framework_path_clone = framework_path.clone();
 83        let languages_clone = languages.clone();
 84        let exercise_names = args.exercise_names.clone();
 85        let all_flag = args.all;
 86
 87        cx.spawn(async move |cx| {
 88            // Authenticate all model providers first
 89            cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
 90                .unwrap()
 91                .await
 92                .unwrap();
 93            cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
 94                .unwrap()
 95                .await
 96                .unwrap();
 97
 98            println!("framework path: {}", framework_path_clone.display());
 99
100            let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
101
102            println!("base sha: {}", base_sha);
103
104            let all_exercises = find_exercises(
105                &framework_path_clone,
106                &languages_clone
107                    .iter()
108                    .map(|s| s.as_str())
109                    .collect::<Vec<_>>(),
110                args.max_exercises_per_language,
111            )
112            .unwrap();
113            println!("Found {} exercises total", all_exercises.len());
114
115            // Filter exercises if specific ones were requested
116            let exercises_to_run = if !exercise_names.is_empty() {
117                // If exercise names are specified, filter by them regardless of --all flag
118                all_exercises
119                    .into_iter()
120                    .filter(|path| {
121                        let name = get_exercise_name(path);
122                        exercise_names.iter().any(|filter| name.contains(filter))
123                    })
124                    .collect()
125            } else if all_flag {
126                // Only use all_flag if no exercise names are specified
127                all_exercises
128            } else {
129                // Default behavior (no filters)
130                all_exercises
131            };
132
133            println!("Will run {} exercises", exercises_to_run.len());
134
135            // Create exercise eval tasks - each exercise is a single task that will run templates sequentially
136            let exercise_tasks: Vec<_> = exercises_to_run
137                .into_iter()
138                .map(|exercise_path| {
139                    let exercise_name = get_exercise_name(&exercise_path);
140                    let model_clone = model.clone();
141                    let app_state_clone = app_state.clone();
142                    let base_sha_clone = base_sha.clone();
143                    let framework_path_clone = framework_path_clone.clone();
144                    let cx_clone = cx.clone();
145
146                    async move {
147                        println!("Processing exercise: {}", exercise_name);
148                        let mut exercise_results = Vec::new();
149
150                        match run_exercise_eval(
151                            exercise_path.clone(),
152                            model_clone.clone(),
153                            app_state_clone.clone(),
154                            base_sha_clone.clone(),
155                            framework_path_clone.clone(),
156                            cx_clone.clone(),
157                        )
158                        .await
159                        {
160                            Ok(result) => {
161                                println!("Completed {}", exercise_name);
162                                exercise_results.push(result);
163                            }
164                            Err(err) => {
165                                println!("Error running {}: {}", exercise_name, err);
166                            }
167                        }
168
169                        // Save results for this exercise
170                        if !exercise_results.is_empty() {
171                            if let Err(err) =
172                                save_eval_results(&exercise_path, exercise_results.clone()).await
173                            {
174                                println!("Error saving results for {}: {}", exercise_name, err);
175                            } else {
176                                println!("Saved results for {}", exercise_name);
177                            }
178                        }
179
180                        exercise_results
181                    }
182                })
183                .collect();
184
185            println!(
186                "Running {} exercises with concurrency: {}",
187                exercise_tasks.len(),
188                args.concurrency
189            );
190
191            // Run exercises concurrently, with each exercise running its templates sequentially
192            let all_results = stream::iter(exercise_tasks)
193                .buffer_unordered(args.concurrency)
194                .flat_map(stream::iter)
195                .collect::<Vec<_>>()
196                .await;
197
198            println!("Completed {} evaluation runs", all_results.len());
199            cx.update(|cx| cx.quit()).unwrap();
200        })
201        .detach();
202    });
203
204    println!("Done running evals");
205}