use crate::git_commands::{run_git, setup_temp_repo};
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
use crate::{get_exercise_language, get_exercise_name};
use agent::RequestKind;
use anyhow::{Result, anyhow};
use collections::HashMap;
use gpui::{App, Task};
use language_model::{LanguageModel, TokenUsage};
use serde::{Deserialize, Serialize};
use std::{
    fs,
    io::Write,
    path::{Path, PathBuf},
    sync::Arc,
    time::{Duration, SystemTime},
};

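/// A single evaluation result as persisted to `evals.json`.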
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResult {
    pub exercise_name: String,
    pub diff: String,
    pub assistant_response: String,
    pub elapsed_time_ms: u128,
    pub timestamp: u128,
    // Token usage fields
    pub input_tokens: usize,
    pub output_tokens: usize,
    pub total_tokens: usize,
    pub tool_use_counts: usize,
}

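/// Aggregated output of a single [`Eval::run`].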
pub struct EvalOutput {
    pub diff: String,
    pub last_message: String,
    pub elapsed_time: Duration,
    pub assistant_response_count: usize,
    pub tool_use_counts: HashMap<Arc<str>, u32>,
    pub token_usage: TokenUsage,
}

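/// Repository coordinates loaded from an eval's `setup.json`.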
#[derive(Deserialize)]
pub struct EvalSetup {
    pub url: String,
    pub base_sha: String,
}

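/// A single evaluation: a repository checkout plus the prompt to send to the model.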
pub struct Eval {
    pub repo_path: PathBuf,
    pub eval_setup: EvalSetup,
    pub user_prompt: String,
}

impl Eval {
    // Retained for potential future use; intentionally unused for now.
    #[allow(dead_code)]
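    /// Loads an eval from `prompt.txt` and `setup.json` under `path`, deriving
    /// the repository checkout location from the setup URL.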
    pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result<Self> {
        let prompt_path = path.join("prompt.txt");
        let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
        let setup_path = path.join("setup.json");
        let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
        let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;

        // Helper used only by `load`: derive a directory name from the repo URL.
        fn repo_dir_name(url: &str) -> String {
            url.trim_start_matches("https://")
                .replace(|c: char| !c.is_alphanumeric(), "_")
        }

        let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));

        Ok(Eval {
            repo_path,
            eval_setup,
            user_prompt,
        })
    }

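    /// Checks out `base_sha`, sends the user prompt to the model through a
    /// headless assistant, waits for completion, and collects the resulting
    /// git diff along with response and token statistics.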
    pub fn run(
        self,
        app_state: Arc<HeadlessAppState>,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Task<Result<EvalOutput>> {
        cx.spawn(async move |cx| {
            run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?;

            let (assistant, done_rx) =
                cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;

            let _worktree = assistant
                .update(cx, |assistant, cx| {
                    assistant.project.update(cx, |project, cx| {
                        project.create_worktree(&self.repo_path, true, cx)
                    })
                })?
                .await?;

            let start_time = SystemTime::now();

            let (system_prompt_context, load_error) = cx
                .update(|cx| {
                    assistant
                        .read(cx)
                        .thread
                        .read(cx)
                        .load_system_prompt_context(cx)
                })?
                .await;

            if let Some(load_error) = load_error {
                return Err(anyhow!("{:?}", load_error));
            }

            assistant.update(cx, |assistant, cx| {
                assistant.thread.update(cx, |thread, cx| {
                    let context = vec![];
                    thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
                    thread.set_system_prompt_context(system_prompt_context);
                    thread.send_to_model(model, RequestKind::Chat, cx);
                });
            })?;

            done_rx.recv().await??;

            // Check for untracked files and stage them so they appear in the diff.
            println!("Checking for untracked files:");
            let untracked = run_git(
                &self.repo_path,
                &["ls-files", "--others", "--exclude-standard"],
            )
            .await?;
            if untracked.is_empty() {
                println!("No untracked files found");
            } else {
                // Stage everything so the new files show up in the staged diff.
                println!("Adding untracked files to git");
                run_git(&self.repo_path, &["add", "."]).await?;
            }

            // Get git status (result currently unused; handy when debugging)
            let _status = run_git(&self.repo_path, &["status", "--short"]).await?;

            let elapsed_time = start_time.elapsed()?;

            // Get diff of staged changes (the files we just added)
            let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?;

            // Get diff of unstaged changes
            let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?;

            // Combine both diffs
            let diff = if unstaged_diff.is_empty() {
                staged_diff
            } else if staged_diff.is_empty() {
                unstaged_diff
            } else {
                format!(
                    "# Staged changes\n{}\n\n# Unstaged changes\n{}",
                    staged_diff, unstaged_diff
                )
            };

            assistant.update(cx, |assistant, cx| {
                let thread = assistant.thread.read(cx);
                let last_message = thread
                    .messages()
                    .last()
                    .ok_or_else(|| anyhow!("Thread contains no messages"))?;
                if last_message.role != language_model::Role::Assistant {
                    return Err(anyhow!("Last message is not from assistant"));
                }
                let assistant_response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                Ok(EvalOutput {
                    diff,
                    last_message: last_message.to_string(),
                    elapsed_time,
                    assistant_response_count,
                    tool_use_counts: assistant.tool_use_counts.clone(),
                    token_usage: thread.cumulative_token_usage(),
                })
            })?
        })
    }
}

impl EvalOutput {
    // Retained for potential future use; intentionally unused for now.
    #[allow(dead_code)]
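    /// Writes the diff, assistant response, and run metrics to `output_dir`,
    /// appending this run's metrics to `metrics.json` keyed by timestamp.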
    pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> {
        // Create the output directory if it doesn't exist
        fs::create_dir_all(output_dir)?;

        // Save the diff to a file
        let diff_path = output_dir.join("diff.patch");
        let mut diff_file = fs::File::create(&diff_path)?;
        diff_file.write_all(self.diff.as_bytes())?;

        // Save the last message to a file
        let message_path = output_dir.join("assistant_response.txt");
        let mut message_file = fs::File::create(&message_path)?;
        message_file.write_all(self.last_message.as_bytes())?;

        // Current metrics for this run
        let current_metrics = serde_json::json!({
            "elapsed_time_ms": self.elapsed_time.as_millis(),
            "assistant_response_count": self.assistant_response_count,
            "tool_use_counts": self.tool_use_counts,
            "token_usage": self.token_usage,
            "eval_output_value": eval_output_value,
        });

        // Get current timestamp in milliseconds
        let timestamp = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)?
            .as_millis()
            .to_string();

        // Path to metrics file
        let metrics_path = output_dir.join("metrics.json");

        // Load existing metrics if the file exists, or create a new object
        let mut historical_metrics = if metrics_path.exists() {
            let metrics_content = fs::read_to_string(&metrics_path)?;
            serde_json::from_str::<serde_json::Value>(&metrics_content)
                .unwrap_or_else(|_| serde_json::json!({}))
        } else {
            serde_json::json!({})
        };

        // Add this run under its timestamp
        if let serde_json::Value::Object(ref mut map) = historical_metrics {
            map.insert(timestamp, current_metrics);
        }

        // Write updated metrics back to file
        let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
        let mut metrics_file = fs::File::create(&metrics_path)?;
        metrics_file.write_all(metrics_json.as_bytes())?;

        Ok(())
    }
}

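/// Reads the exercise instructions from `.docs/instructions.md`.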
pub async fn read_instructions(exercise_path: &Path) -> Result<String> {
    let instructions_path = exercise_path.join(".docs").join("instructions.md");
    println!("Reading instructions from: {}", instructions_path.display());
    let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?;
    Ok(instructions)
}

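/// Appends `results` to `evaluation/evals.json` under `exercise_path`, keyed
/// by exercise name and then by the timestamp of this batch.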
pub async fn save_eval_results(exercise_path: &Path, results: Vec<EvalResult>) -> Result<()> {
    let eval_dir = exercise_path.join("evaluation");
    fs::create_dir_all(&eval_dir)?;

    let eval_file = eval_dir.join("evals.json");

    println!("Saving evaluation results to: {}", eval_file.display());
    println!(
        "Results to save: {} evaluations for exercise path: {}",
        results.len(),
        exercise_path.display()
    );

    // Check file existence before reading/writing
    if eval_file.exists() {
        println!("Existing evals.json file found, will update it");
    } else {
        println!("No existing evals.json file found, will create new one");
    }

    // Structure organizing evaluations by exercise name and timestamp
    let mut eval_data: serde_json::Value = if eval_file.exists() {
        let content = fs::read_to_string(&eval_file)?;
        serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({}))
    } else {
        serde_json::json!({})
    };

    // Get current timestamp for this batch of results
    let timestamp = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)?
        .as_millis()
        .to_string();

    // Group the new results by exercise name
    for result in results {
        let exercise_name = &result.exercise_name;

        println!("Adding result: exercise={}", exercise_name);

        // Ensure the exercise entry exists
        if eval_data.get(exercise_name).is_none() {
            eval_data[exercise_name] = serde_json::json!({});
        }

        // Store this result under the current timestamp
        eval_data[exercise_name][&timestamp] = serde_json::to_value(&result)?;
    }

304
305 // Write back to file with pretty formatting
306 let json_content = serde_json::to_string_pretty(&eval_data)?;
307 match fs::write(&eval_file, json_content) {
308 Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()),
309 Err(e) => println!("✗ Failed to write results file: {}", e),
310 }
311
312 Ok(())
313}
314
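/// Runs one exercise end to end: reads its instructions, copies it into a
/// temporary repository at `base_sha`, runs the model against it, and returns
/// an [`EvalResult`] with the diff and token statistics.
///
/// A minimal usage sketch (hypothetical paths and handles, not compiled here):
///
/// ```ignore
/// let result = run_exercise_eval(
///     PathBuf::from("exercises/practice/hello-world"), // hypothetical exercise path
///     model.clone(),
///     app_state.clone(),
///     "abc1234".to_string(), // base SHA used to seed the temp repo
///     PathBuf::from("framework"), // currently unused
///     cx.clone(),
/// )
/// .await?;
/// save_eval_results(&exercise_path, vec![result]).await?;
/// ```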
pub async fn run_exercise_eval(
    exercise_path: PathBuf,
    model: Arc<dyn LanguageModel>,
    app_state: Arc<HeadlessAppState>,
    base_sha: String,
    _framework_path: PathBuf,
    cx: gpui::AsyncApp,
) -> Result<EvalResult> {
    let exercise_name = get_exercise_name(&exercise_path);
    let language = get_exercise_language(&exercise_path)?;
    let mut instructions = read_instructions(&exercise_path).await?;
    instructions.push_str(&format!(
        "\n\nWhen writing the code for this prompt, use {} to achieve the goal.",
        language
    ));

    println!("Running evaluation for exercise: {}", exercise_name);

    // Create temporary directory with exercise files
    let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?;
    let temp_path = temp_dir.path().to_path_buf();

    let local_commit_sha = run_git(&temp_path, &["rev-parse", "HEAD"]).await?;

    let start_time = SystemTime::now();

    // Create a basic eval struct to work with the existing system
    let eval = Eval {
        repo_path: temp_path.clone(),
        eval_setup: EvalSetup {
            url: format!("file://{}", temp_path.display()),
            // Use the local commit SHA instead of the framework base SHA
            base_sha: local_commit_sha,
        },
        user_prompt: instructions.clone(),
    };

    // Run the evaluation
    let eval_output = cx
        .update(|cx| eval.run(app_state.clone(), model.clone(), cx))?
        .await?;

    // The diff was already computed during the eval run
    let diff = eval_output.diff.clone();

    let elapsed_time = start_time.elapsed()?;

    // Total tokens are the sum of input and output tokens
    let input_tokens = eval_output.token_usage.input_tokens;
    let output_tokens = eval_output.token_usage.output_tokens;
    let tool_use_counts = eval_output.tool_use_counts.values().sum::<u32>();
    let total_tokens = input_tokens + output_tokens;

    // Save results to evaluation directory
    let result = EvalResult {
        exercise_name: exercise_name.clone(),
        diff,
        assistant_response: eval_output.last_message.clone(),
        elapsed_time_ms: elapsed_time.as_millis(),
        timestamp: SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)?
            .as_millis(),
        // Convert u32 token counts to usize
        input_tokens: input_tokens.try_into().unwrap(),
        output_tokens: output_tokens.try_into().unwrap(),
        total_tokens: total_tokens.try_into().unwrap(),
        tool_use_counts: tool_use_counts.try_into().unwrap(),
    };

    Ok(result)
}