mod eval;
mod get_exercise;
mod git_commands;
mod headless_assistant;
mod judge;
mod templates_eval;

use clap::Parser;
use eval::{run_exercise_eval, save_eval_results};
use futures::stream::{self, StreamExt};
use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
use git_commands::read_base_sha;
use gpui::Application;
use headless_assistant::{authenticate_model_provider, find_model};
use language_model::LanguageModelRegistry;
use reqwest_client::ReqwestClient;
use std::{path::PathBuf, sync::Arc};
use templates_eval::all_templates;

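// Command-line arguments for the eval runner. Example invocation (the binary
// name and option values shown here are illustrative):
//   assistant_eval --all --languages internal --concurrency 5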
#[derive(Parser, Debug)]
#[command(
    name = "assistant_eval",
    disable_version_flag = true,
    before_help = "Tool eval runner"
)]
struct Args {
    /// Substrings used to select which exercises run, matched against exercise names.
    #[arg(long)]
    exercise_names: Vec<String>,
    /// Run all exercises. Ignored when exercise names are provided.
    #[arg(long)]
    all: bool,
    /// Comma-separated list of languages to evaluate (default: "internal").
    /// "internal" refers to data generated from the agent panel.
    #[arg(long, default_value = "internal")]
    languages: String,
    /// Name of the model (default: "claude-3-7-sonnet-latest").
    #[arg(long, default_value = "claude-3-7-sonnet-latest")]
    model_name: String,
    /// Name of the editor model (default: value of `--model_name`).
    #[arg(long)]
    editor_model_name: Option<String>,
    /// Name of the judge model (default: value of `--model_name`).
    #[arg(long)]
    judge_model_name: Option<String>,
    /// Number of evaluations to run concurrently (default: 3).
    #[arg(short, long, default_value = "3")]
    concurrency: usize,
    /// Maximum number of exercises to evaluate per language.
    #[arg(long)]
    max_exercises_per_language: Option<usize>,
}

// The order in which templates are executed for each exercise.
const TEMPLATE_EXECUTION_ORDER: [&str; 3] = [
    "ProjectCreation",
    "CodeModification",
    "ConversationalGuidance",
];

fn main() {
    env_logger::init();
    let args = Args::parse();
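    // Shared HTTP client, handed to the headless GPUI application below.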
    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client.clone());

    // Path to the zed-ace-framework repo
    let framework_path = PathBuf::from("../zed-ace-framework")
        .canonicalize()
        .unwrap();
    // Use owned Strings (not slices borrowing from `args`) so the list can be
    // moved into the async task below.
    let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();

    println!("Using zed-ace-framework at: {:?}", framework_path);
    println!("Evaluating languages: {:?}", languages);

    app.run(move |cx| {
        let app_state = headless_assistant::init(cx);

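        // Resolve the main model; the editor and judge models fall back to it
        // when not specified on the command line.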
        let model = find_model(&args.model_name, cx).unwrap();
        let editor_model = if let Some(model_name) = &args.editor_model_name {
            find_model(model_name, cx).unwrap()
        } else {
            model.clone()
        };
        let judge_model = if let Some(model_name) = &args.judge_model_name {
            find_model(model_name, cx).unwrap()
        } else {
            model.clone()
        };

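        // Record the chosen models in the global language model registry.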
        LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
            registry.set_active_model(Some(model.clone()), cx);
            registry.set_editor_model(Some(editor_model.clone()), cx);
        });

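        // Capture provider IDs now; each provider is authenticated inside the
        // async task before any evals run.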
        let model_provider_id = model.provider_id();
        let editor_model_provider_id = editor_model.provider_id();
        let judge_model_provider_id = judge_model.provider_id();

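        // Clone everything the async task needs to own.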
        let framework_path_clone = framework_path.clone();
        let languages_clone = languages.clone();
        let exercise_names = args.exercise_names.clone();
        let all_flag = args.all;

        cx.spawn(async move |cx| {
            // Authenticate all model providers first
            cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();
            cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();
            cx.update(|cx| authenticate_model_provider(judge_model_provider_id.clone(), cx))
                .unwrap()
                .await
                .unwrap();

            // Read base SHA from setup.json
            let base_sha = read_base_sha(&framework_path_clone).await.unwrap();

            // Find all exercises for the specified languages
            let all_exercises = find_exercises(
                &framework_path_clone,
                &languages_clone
                    .iter()
                    .map(|s| s.as_str())
                    .collect::<Vec<_>>(),
                args.max_exercises_per_language,
            )
            .unwrap();
            println!("Found {} exercises total", all_exercises.len());

            // Filter exercises when specific names were requested; otherwise run
            // them all.
            let exercises_to_run = if !exercise_names.is_empty() {
                // Exercise names take precedence over the --all flag.
                all_exercises
                    .into_iter()
                    .filter(|path| {
                        let name = get_exercise_name(path);
                        exercise_names.iter().any(|filter| name.contains(filter))
                    })
                    .collect()
            } else {
                // With no name filters, the default and --all behave identically.
                if all_flag {
                    println!("Running all exercises (--all)");
                }
                all_exercises
            };

            println!("Will run {} exercises", exercises_to_run.len());

            // Get all templates and sort them according to the execution order
            let mut templates = all_templates();
            templates.sort_by_key(|template| {
                TEMPLATE_EXECUTION_ORDER
                    .iter()
                    .position(|&name| name == template.name)
                    .unwrap_or(usize::MAX)
            });

            // Create exercise eval tasks - each exercise is a single task that will run templates sequentially
            let exercise_tasks: Vec<_> = exercises_to_run
                .into_iter()
                .map(|exercise_path| {
                    let exercise_name = get_exercise_name(&exercise_path);
                    let templates_clone = templates.clone();
                    let model_clone = model.clone();
                    let judge_model_clone = judge_model.clone();
                    let app_state_clone = app_state.clone();
                    let base_sha_clone = base_sha.clone();
                    let framework_path_clone = framework_path_clone.clone();
                    let cx_clone = cx.clone();

                    async move {
                        println!("Processing exercise: {}", exercise_name);
                        let mut exercise_results = Vec::new();

                        // Determine the language for this exercise
                        let language = match get_exercise_language(&exercise_path) {
                            Ok(lang) => lang,
                            Err(err) => {
                                println!(
                                    "Error determining language for {}: {}",
                                    exercise_name, err
                                );
                                return exercise_results;
                            }
                        };

                        // Run each template sequentially for this exercise
                        for template in templates_clone {
                            // For "multi" or "internal" language, only run the CodeModification template
                            if (language == "multi" || language == "internal")
                                && template.name != "CodeModification"
                            {
                                println!(
                                    "Skipping {} template for {} language",
                                    template.name, language
                                );
                                continue;
                            }

                            match run_exercise_eval(
                                exercise_path.clone(),
                                template.clone(),
                                model_clone.clone(),
                                judge_model_clone.clone(),
                                app_state_clone.clone(),
                                base_sha_clone.clone(),
                                framework_path_clone.clone(),
                                cx_clone.clone(),
                            )
                            .await
                            {
                                Ok(result) => {
                                    println!(
                                        "Completed {} with template {} - score: {}",
                                        exercise_name, template.name, result.score
                                    );
                                    exercise_results.push(result);
                                }
                                Err(err) => {
                                    println!(
                                        "Error running {} with template {}: {}",
                                        exercise_name, template.name, err
                                    );
                                }
                            }
                        }

                        // Save results for this exercise
                        if !exercise_results.is_empty() {
                            if let Err(err) =
                                save_eval_results(&exercise_path, exercise_results.clone()).await
                            {
                                println!("Error saving results for {}: {}", exercise_name, err);
                            } else {
                                println!("Saved results for {}", exercise_name);
                            }
                        }

                        exercise_results
                    }
                })
                .collect();

            println!(
                "Running {} exercises with concurrency: {}",
                exercise_tasks.len(),
                args.concurrency
            );

            // Run exercises concurrently, with each exercise running its templates sequentially
            let all_results = stream::iter(exercise_tasks)
                .buffer_unordered(args.concurrency)
                .flat_map(stream::iter)
                .collect::<Vec<_>>()
                .await;

265 println!("Completed {} evaluation runs", all_results.len());
266 cx.update(|cx| cx.quit()).unwrap();
267 })
268 .detach();
269 });
270
271 println!("Done running evals");
272}