1mod eval;
2mod get_exercise;
3mod git_commands;
4mod headless_assistant;
5mod judge;
6mod templates_eval;
7
8use clap::Parser;
9use eval::{run_exercise_eval, save_eval_results};
10use futures::stream::{self, StreamExt};
11use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
12use git_commands::read_base_sha;
13use gpui::Application;
14use headless_assistant::{authenticate_model_provider, find_model};
15use language_model::LanguageModelRegistry;
16use reqwest_client::ReqwestClient;
17use std::{path::PathBuf, sync::Arc};
18use templates_eval::all_templates;
19
20#[derive(Parser, Debug)]
21#[command(
22 name = "assistant_eval",
23 disable_version_flag = true,
24 before_help = "Tool eval runner"
25)]
26struct Args {
27 /// Match the names of evals to run.
28 #[arg(long)]
29 exercise_names: Vec<String>,
30 /// Runs all exercises, causes the exercise_names to be ignored.
31 #[arg(long)]
32 all: bool,
33 /// Supported language types to evaluate (default: internal).
34 /// Internal is data generated from the agent panel
35 #[arg(long, default_value = "internal")]
36 languages: String,
37 /// Name of the model (default: "claude-3-7-sonnet-latest")
38 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
39 model_name: String,
40 /// Name of the judge model (default: value of `--model_name`).
41 #[arg(long)]
42 judge_model_name: Option<String>,
43 /// Number of evaluations to run concurrently (default: 3)
44 #[arg(short, long, default_value = "3")]
45 concurrency: usize,
46 /// Maximum number of exercises to evaluate per language
47 #[arg(long)]
48 max_exercises_per_language: Option<usize>,
49}
50
// Fixed execution order for templates within a single exercise. Templates
// whose names are not listed here sort after these (see the sort in `main`).
const TEMPLATE_EXECUTION_ORDER: [&str; 3] = [
    "ProjectCreation",
    "CodeModification",
    "ConversationalGuidance",
];
57
58fn main() {
59 env_logger::init();
60 let args = Args::parse();
61 let http_client = Arc::new(ReqwestClient::new());
62 let app = Application::headless().with_http_client(http_client.clone());
63
64 // Path to the zed-ace-framework repo
65 let framework_path = PathBuf::from("../zed-ace-framework")
66 .canonicalize()
67 .unwrap();
68
69 // Fix the 'languages' lifetime issue by creating owned Strings instead of slices
70 let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
71
72 println!("Using zed-ace-framework at: {:?}", framework_path);
73 println!("Evaluating languages: {:?}", languages);
74
75 app.run(move |cx| {
76 let app_state = headless_assistant::init(cx);
77
78 let model = find_model(&args.model_name, cx).unwrap();
79 let judge_model = if let Some(model_name) = &args.judge_model_name {
80 find_model(model_name, cx).unwrap()
81 } else {
82 model.clone()
83 };
84
85 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
86 registry.set_default_model(Some(model.clone()), cx);
87 });
88
89 let model_provider_id = model.provider_id();
90 let judge_model_provider_id = judge_model.provider_id();
91
92 let framework_path_clone = framework_path.clone();
93 let languages_clone = languages.clone();
94 let exercise_names = args.exercise_names.clone();
95 let all_flag = args.all;
96
97 cx.spawn(async move |cx| {
98 // Authenticate all model providers first
99 cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
100 .unwrap()
101 .await
102 .unwrap();
103 cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
104 .unwrap()
105 .await
106 .unwrap();
107
108 // Read base SHA from setup.json
109 let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
110
111 // Find all exercises for the specified languages
112 let all_exercises = find_exercises(
113 &framework_path_clone,
114 &languages_clone
115 .iter()
116 .map(|s| s.as_str())
117 .collect::<Vec<_>>(),
118 args.max_exercises_per_language,
119 )
120 .unwrap();
121 println!("Found {} exercises total", all_exercises.len());
122
123 // Filter exercises if specific ones were requested
124 let exercises_to_run = if !exercise_names.is_empty() {
125 // If exercise names are specified, filter by them regardless of --all flag
126 all_exercises
127 .into_iter()
128 .filter(|path| {
129 let name = get_exercise_name(path);
130 exercise_names.iter().any(|filter| name.contains(filter))
131 })
132 .collect()
133 } else if all_flag {
134 // Only use all_flag if no exercise names are specified
135 all_exercises
136 } else {
137 // Default behavior (no filters)
138 all_exercises
139 };
140
141 println!("Will run {} exercises", exercises_to_run.len());
142
143 // Get all templates and sort them according to the execution order
144 let mut templates = all_templates();
145 templates.sort_by_key(|template| {
146 TEMPLATE_EXECUTION_ORDER
147 .iter()
148 .position(|&name| name == template.name)
149 .unwrap_or(usize::MAX)
150 });
151
152 // Create exercise eval tasks - each exercise is a single task that will run templates sequentially
153 let exercise_tasks: Vec<_> = exercises_to_run
154 .into_iter()
155 .map(|exercise_path| {
156 let exercise_name = get_exercise_name(&exercise_path);
157 let templates_clone = templates.clone();
158 let model_clone = model.clone();
159 let judge_model_clone = judge_model.clone();
160 let app_state_clone = app_state.clone();
161 let base_sha_clone = base_sha.clone();
162 let framework_path_clone = framework_path_clone.clone();
163 let cx_clone = cx.clone();
164
165 async move {
166 println!("Processing exercise: {}", exercise_name);
167 let mut exercise_results = Vec::new();
168
169 // Determine the language for this exercise
170 let language = match get_exercise_language(&exercise_path) {
171 Ok(lang) => lang,
172 Err(err) => {
173 println!(
174 "Error determining language for {}: {}",
175 exercise_name, err
176 );
177 return exercise_results;
178 }
179 };
180
181 // Run each template sequentially for this exercise
182 for template in templates_clone {
183 // For "multi" or "internal" language, only run the CodeModification template
184 if (language == "multi" || language == "internal")
185 && template.name != "CodeModification"
186 {
187 println!(
188 "Skipping {} template for {} language",
189 template.name, language
190 );
191 continue;
192 }
193
194 match run_exercise_eval(
195 exercise_path.clone(),
196 template.clone(),
197 model_clone.clone(),
198 judge_model_clone.clone(),
199 app_state_clone.clone(),
200 base_sha_clone.clone(),
201 framework_path_clone.clone(),
202 cx_clone.clone(),
203 )
204 .await
205 {
206 Ok(result) => {
207 println!(
208 "Completed {} with template {} - score: {}",
209 exercise_name, template.name, result.score
210 );
211 exercise_results.push(result);
212 }
213 Err(err) => {
214 println!(
215 "Error running {} with template {}: {}",
216 exercise_name, template.name, err
217 );
218 }
219 }
220 }
221
222 // Save results for this exercise
223 if !exercise_results.is_empty() {
224 if let Err(err) =
225 save_eval_results(&exercise_path, exercise_results.clone()).await
226 {
227 println!("Error saving results for {}: {}", exercise_name, err);
228 } else {
229 println!("Saved results for {}", exercise_name);
230 }
231 }
232
233 exercise_results
234 }
235 })
236 .collect();
237
238 println!(
239 "Running {} exercises with concurrency: {}",
240 exercise_tasks.len(),
241 args.concurrency
242 );
243
244 // Run exercises concurrently, with each exercise running its templates sequentially
245 let all_results = stream::iter(exercise_tasks)
246 .buffer_unordered(args.concurrency)
247 .flat_map(stream::iter)
248 .collect::<Vec<_>>()
249 .await;
250
251 println!("Completed {} evaluation runs", all_results.len());
252 cx.update(|cx| cx.quit()).unwrap();
253 })
254 .detach();
255 });
256
257 println!("Done running evals");
258}