1mod eval;
2mod get_exercise;
3mod git_commands;
4mod headless_assistant;
5
6use clap::Parser;
7use eval::{run_exercise_eval, save_eval_results};
8use futures::stream::{self, StreamExt};
9use get_exercise::{find_exercises, get_exercise_language, get_exercise_name};
10use git_commands::read_base_sha;
11use gpui::Application;
12use headless_assistant::{authenticate_model_provider, find_model};
13use language_model::LanguageModelRegistry;
14use reqwest_client::ReqwestClient;
15use std::{path::PathBuf, sync::Arc};
16
17#[derive(Parser, Debug)]
18#[command(
19 name = "agent_eval",
20 disable_version_flag = true,
21 before_help = "Tool eval runner"
22)]
23struct Args {
24 /// Match the names of evals to run.
25 #[arg(long)]
26 exercise_names: Vec<String>,
27 /// Runs all exercises, causes the exercise_names to be ignored.
28 #[arg(long)]
29 all: bool,
30 /// Supported language types to evaluate (default: internal).
31 /// Internal is data generated from the agent panel
32 #[arg(long, default_value = "internal")]
33 languages: String,
34 /// Name of the model (default: "claude-3-7-sonnet-latest")
35 #[arg(long, default_value = "claude-3-7-sonnet-latest")]
36 model_name: String,
37 /// Name of the editor model (default: value of `--model_name`).
38 #[arg(long)]
39 editor_model_name: Option<String>,
40 /// Number of evaluations to run concurrently (default: 3)
41 #[arg(short, long, default_value = "5")]
42 concurrency: usize,
43 /// Maximum number of exercises to evaluate per language
44 #[arg(long)]
45 max_exercises_per_language: Option<usize>,
46}
47
48fn main() {
49 env_logger::init();
50 let args = Args::parse();
51 let http_client = Arc::new(ReqwestClient::new());
52 let app = Application::headless().with_http_client(http_client.clone());
53
54 // Path to the zed-ace-framework repo
55 let framework_path = PathBuf::from("../zed-ace-framework")
56 .canonicalize()
57 .unwrap();
58
59 // Fix the 'languages' lifetime issue by creating owned Strings instead of slices
60 let languages: Vec<String> = args.languages.split(',').map(|s| s.to_string()).collect();
61
62 println!("Using zed-ace-framework at: {:?}", framework_path);
63 println!("Evaluating languages: {:?}", languages);
64
65 app.run(move |cx| {
66 let app_state = headless_assistant::init(cx);
67
68 let model = find_model(&args.model_name, cx).unwrap();
69 let editor_model = if let Some(model_name) = &args.editor_model_name {
70 find_model(model_name, cx).unwrap()
71 } else {
72 model.clone()
73 };
74
75 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
76 registry.set_default_model(Some(model.clone()), cx);
77 });
78
79 let model_provider_id = model.provider_id();
80 let editor_model_provider_id = editor_model.provider_id();
81
82 let framework_path_clone = framework_path.clone();
83 let languages_clone = languages.clone();
84 let exercise_names = args.exercise_names.clone();
85 let all_flag = args.all;
86
87 cx.spawn(async move |cx| {
88 // Authenticate all model providers first
89 cx.update(|cx| authenticate_model_provider(model_provider_id.clone(), cx))
90 .unwrap()
91 .await
92 .unwrap();
93 cx.update(|cx| authenticate_model_provider(editor_model_provider_id.clone(), cx))
94 .unwrap()
95 .await
96 .unwrap();
97
98 println!("framework path: {}", framework_path_clone.display());
99
100 let base_sha = read_base_sha(&framework_path_clone).await.unwrap();
101
102 println!("base sha: {}", base_sha);
103
104 let all_exercises = find_exercises(
105 &framework_path_clone,
106 &languages_clone
107 .iter()
108 .map(|s| s.as_str())
109 .collect::<Vec<_>>(),
110 args.max_exercises_per_language,
111 )
112 .unwrap();
113 println!("Found {} exercises total", all_exercises.len());
114
115 // Filter exercises if specific ones were requested
116 let exercises_to_run = if !exercise_names.is_empty() {
117 // If exercise names are specified, filter by them regardless of --all flag
118 all_exercises
119 .into_iter()
120 .filter(|path| {
121 let name = get_exercise_name(path);
122 exercise_names.iter().any(|filter| name.contains(filter))
123 })
124 .collect()
125 } else if all_flag {
126 // Only use all_flag if no exercise names are specified
127 all_exercises
128 } else {
129 // Default behavior (no filters)
130 all_exercises
131 };
132
133 println!("Will run {} exercises", exercises_to_run.len());
134
135 // Create exercise eval tasks - each exercise is a single task that will run templates sequentially
136 let exercise_tasks: Vec<_> = exercises_to_run
137 .into_iter()
138 .map(|exercise_path| {
139 let exercise_name = get_exercise_name(&exercise_path);
140 let model_clone = model.clone();
141 let app_state_clone = app_state.clone();
142 let base_sha_clone = base_sha.clone();
143 let framework_path_clone = framework_path_clone.clone();
144 let cx_clone = cx.clone();
145
146 async move {
147 println!("Processing exercise: {}", exercise_name);
148 let mut exercise_results = Vec::new();
149
150 match run_exercise_eval(
151 exercise_path.clone(),
152 model_clone.clone(),
153 app_state_clone.clone(),
154 base_sha_clone.clone(),
155 framework_path_clone.clone(),
156 cx_clone.clone(),
157 )
158 .await
159 {
160 Ok(result) => {
161 println!("Completed {}", exercise_name);
162 exercise_results.push(result);
163 }
164 Err(err) => {
165 println!("Error running {}: {}", exercise_name, err);
166 }
167 }
168
169 // Save results for this exercise
170 if !exercise_results.is_empty() {
171 if let Err(err) =
172 save_eval_results(&exercise_path, exercise_results.clone()).await
173 {
174 println!("Error saving results for {}: {}", exercise_name, err);
175 } else {
176 println!("Saved results for {}", exercise_name);
177 }
178 }
179
180 exercise_results
181 }
182 })
183 .collect();
184
185 println!(
186 "Running {} exercises with concurrency: {}",
187 exercise_tasks.len(),
188 args.concurrency
189 );
190
191 // Run exercises concurrently, with each exercise running its templates sequentially
192 let all_results = stream::iter(exercise_tasks)
193 .buffer_unordered(args.concurrency)
194 .flat_map(stream::iter)
195 .collect::<Vec<_>>()
196 .await;
197
198 println!("Completed {} evaluation runs", all_results.len());
199 cx.update(|cx| cx.quit()).unwrap();
200 })
201 .detach();
202 });
203
204 println!("Done running evals");
205}