use crate::git_commands::{run_git, setup_temp_repo};
use crate::headless_assistant::{HeadlessAppState, HeadlessAssistant};
use crate::{get_exercise_language, get_exercise_name};
use agent::RequestKind;
use anyhow::{Result, anyhow};
use collections::HashMap;
use gpui::{App, Task};
use language_model::{LanguageModel, TokenUsage};
use serde::{Deserialize, Serialize};
use std::{
    fs,
    io::Write,
    path::{Path, PathBuf},
    sync::Arc,
    time::{Duration, SystemTime},
};

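/// A single evaluation result as persisted to `evals.json`.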
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct EvalResult {
    pub exercise_name: String,
    pub diff: String,
    pub assistant_response: String,
    pub elapsed_time_ms: u128,
    pub timestamp: u128,
    // Token usage fields
    pub input_tokens: usize,
    pub output_tokens: usize,
    pub total_tokens: usize,
    pub tool_use_counts: usize,
}

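/// Aggregated output of a single [`Eval::run`].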
pub struct EvalOutput {
    pub diff: String,
    pub last_message: String,
    pub elapsed_time: Duration,
    pub assistant_response_count: usize,
    pub tool_use_counts: HashMap<Arc<str>, u32>,
    pub token_usage: TokenUsage,
}

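/// Repository coordinates loaded from an eval's `setup.json`.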
#[derive(Deserialize)]
pub struct EvalSetup {
    pub url: String,
    pub base_sha: String,
}

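/// A single evaluation: a repository checkout plus the prompt to send to the model.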
pub struct Eval {
    pub repo_path: PathBuf,
    pub eval_setup: EvalSetup,
    pub user_prompt: String,
}

impl Eval {
    // Retained for potential future use; intentionally unused for now.
    #[allow(dead_code)]
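    /// Loads an eval from `prompt.txt` and `setup.json` under `path`, deriving
    /// the repository checkout location from the setup URL.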
    pub async fn load(_name: String, path: PathBuf, repos_dir: &Path) -> Result<Self> {
        let prompt_path = path.join("prompt.txt");
        let user_prompt = smol::unblock(|| std::fs::read_to_string(prompt_path)).await?;
        let setup_path = path.join("setup.json");
        let setup_contents = smol::unblock(|| std::fs::read_to_string(setup_path)).await?;
        let eval_setup = serde_json_lenient::from_str_lenient::<EvalSetup>(&setup_contents)?;

        // Helper used only by `load`: derive a directory name from the repo URL.
        fn repo_dir_name(url: &str) -> String {
            url.trim_start_matches("https://")
                .replace(|c: char| !c.is_alphanumeric(), "_")
        }

        let repo_path = repos_dir.join(repo_dir_name(&eval_setup.url));

        Ok(Eval {
            repo_path,
            eval_setup,
            user_prompt,
        })
    }

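    /// Checks out `base_sha`, sends the user prompt to the model through a
    /// headless assistant, waits for completion, and collects the resulting
    /// git diff along with response and token statistics.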
    pub fn run(
        self,
        app_state: Arc<HeadlessAppState>,
        model: Arc<dyn LanguageModel>,
        cx: &mut App,
    ) -> Task<Result<EvalOutput>> {
        cx.spawn(async move |cx| {
            run_git(&self.repo_path, &["checkout", &self.eval_setup.base_sha]).await?;

            let (assistant, done_rx) =
                cx.update(|cx| HeadlessAssistant::new(app_state.clone(), cx))??;

            let _worktree = assistant
                .update(cx, |assistant, cx| {
                    assistant.project.update(cx, |project, cx| {
                        project.create_worktree(&self.repo_path, true, cx)
                    })
                })?
                .await?;

            let start_time = SystemTime::now();

            let (system_prompt_context, load_error) = cx
                .update(|cx| {
                    assistant
                        .read(cx)
                        .thread
                        .read(cx)
                        .load_system_prompt_context(cx)
                })?
                .await;

            if let Some(load_error) = load_error {
                return Err(anyhow!("{:?}", load_error));
            }

            assistant.update(cx, |assistant, cx| {
                assistant.thread.update(cx, |thread, cx| {
                    let context = vec![];
                    thread.insert_user_message(self.user_prompt.clone(), context, None, cx);
                    thread.set_system_prompt_context(system_prompt_context);
                    thread.send_to_model(model, RequestKind::Chat, cx);
                });
            })?;

            done_rx.recv().await??;

            // Check for untracked files and stage them so they appear in the diff.
            println!("Checking for untracked files:");
            let untracked = run_git(
                &self.repo_path,
                &["ls-files", "--others", "--exclude-standard"],
            )
            .await?;
            if untracked.is_empty() {
                println!("No untracked files found");
            } else {
                // Stage everything so the new files show up in the staged diff.
                println!("Adding untracked files to git");
                run_git(&self.repo_path, &["add", "."]).await?;
            }

            // Get git status (result currently unused; handy when debugging)
            let _status = run_git(&self.repo_path, &["status", "--short"]).await?;

            let elapsed_time = start_time.elapsed()?;

            // Get diff of staged changes (the files we just added)
            let staged_diff = run_git(&self.repo_path, &["diff", "--staged"]).await?;

            // Get diff of unstaged changes
            let unstaged_diff = run_git(&self.repo_path, &["diff"]).await?;

            // Combine both diffs
            let diff = if unstaged_diff.is_empty() {
                staged_diff
            } else if staged_diff.is_empty() {
                unstaged_diff
            } else {
                format!(
                    "# Staged changes\n{}\n\n# Unstaged changes\n{}",
                    staged_diff, unstaged_diff
                )
            };

            assistant.update(cx, |assistant, cx| {
                let thread = assistant.thread.read(cx);
                let last_message = thread
                    .messages()
                    .last()
                    .ok_or_else(|| anyhow!("Thread contains no messages"))?;
                if last_message.role != language_model::Role::Assistant {
                    return Err(anyhow!("Last message is not from assistant"));
                }
                let assistant_response_count = thread
                    .messages()
                    .filter(|message| message.role == language_model::Role::Assistant)
                    .count();
                Ok(EvalOutput {
                    diff,
                    last_message: last_message.to_string(),
                    elapsed_time,
                    assistant_response_count,
                    tool_use_counts: assistant.tool_use_counts.clone(),
                    token_usage: thread.cumulative_token_usage(),
                })
            })?
        })
    }
}

impl EvalOutput {
    // Retained for potential future use; intentionally unused for now.
    #[allow(dead_code)]
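    /// Writes the diff, assistant response, and run metrics to `output_dir`,
    /// appending this run's metrics to `metrics.json` keyed by timestamp.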
    pub fn save_to_directory(&self, output_dir: &Path, eval_output_value: String) -> Result<()> {
        // Create the output directory if it doesn't exist
        fs::create_dir_all(output_dir)?;

        // Save the diff to a file
        let diff_path = output_dir.join("diff.patch");
        let mut diff_file = fs::File::create(&diff_path)?;
        diff_file.write_all(self.diff.as_bytes())?;

        // Save the last message to a file
        let message_path = output_dir.join("assistant_response.txt");
        let mut message_file = fs::File::create(&message_path)?;
        message_file.write_all(self.last_message.as_bytes())?;

        // Current metrics for this run
        let current_metrics = serde_json::json!({
            "elapsed_time_ms": self.elapsed_time.as_millis(),
            "assistant_response_count": self.assistant_response_count,
            "tool_use_counts": self.tool_use_counts,
            "token_usage": self.token_usage,
            "eval_output_value": eval_output_value,
        });

        // Get current timestamp in milliseconds
        let timestamp = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)?
            .as_millis()
            .to_string();

        // Path to metrics file
        let metrics_path = output_dir.join("metrics.json");

        // Load existing metrics if the file exists, or create a new object
        let mut historical_metrics = if metrics_path.exists() {
            let metrics_content = fs::read_to_string(&metrics_path)?;
            serde_json::from_str::<serde_json::Value>(&metrics_content)
                .unwrap_or_else(|_| serde_json::json!({}))
        } else {
            serde_json::json!({})
        };

        // Add this run under its timestamp
        if let serde_json::Value::Object(ref mut map) = historical_metrics {
            map.insert(timestamp, current_metrics);
        }

        // Write updated metrics back to file
        let metrics_json = serde_json::to_string_pretty(&historical_metrics)?;
        let mut metrics_file = fs::File::create(&metrics_path)?;
        metrics_file.write_all(metrics_json.as_bytes())?;

        Ok(())
    }
}

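/// Reads the exercise instructions from `.docs/instructions.md`.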
pub async fn read_instructions(exercise_path: &Path) -> Result<String> {
    let instructions_path = exercise_path.join(".docs").join("instructions.md");
    println!("Reading instructions from: {}", instructions_path.display());
    let instructions = smol::unblock(move || std::fs::read_to_string(&instructions_path)).await?;
    Ok(instructions)
}

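/// Appends `results` to `evaluation/evals.json` under `exercise_path`, keyed
/// by exercise name and then by the timestamp of this batch.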
pub async fn save_eval_results(exercise_path: &Path, results: Vec<EvalResult>) -> Result<()> {
    let eval_dir = exercise_path.join("evaluation");
    fs::create_dir_all(&eval_dir)?;

    let eval_file = eval_dir.join("evals.json");

    println!("Saving evaluation results to: {}", eval_file.display());
    println!(
        "Results to save: {} evaluations for exercise path: {}",
        results.len(),
        exercise_path.display()
    );

    // Check file existence before reading/writing
    if eval_file.exists() {
        println!("Existing evals.json file found, will update it");
    } else {
        println!("No existing evals.json file found, will create new one");
    }

    // Structure organizing evaluations by exercise name and timestamp
    let mut eval_data: serde_json::Value = if eval_file.exists() {
        let content = fs::read_to_string(&eval_file)?;
        serde_json::from_str(&content).unwrap_or_else(|_| serde_json::json!({}))
    } else {
        serde_json::json!({})
    };

    // Get current timestamp for this batch of results
    let timestamp = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)?
        .as_millis()
        .to_string();

    // Group the new results by exercise name
    for result in results {
        let exercise_name = &result.exercise_name;

        println!("Adding result: exercise={}", exercise_name);

        // Ensure the exercise entry exists
        if eval_data.get(exercise_name).is_none() {
            eval_data[exercise_name] = serde_json::json!({});
        }

        // Store this result under the current timestamp
        eval_data[exercise_name][&timestamp] = serde_json::to_value(&result)?;
    }

304
305 // Write back to file with pretty formatting
306 let json_content = serde_json::to_string_pretty(&eval_data)?;
307 match fs::write(&eval_file, json_content) {
308 Ok(_) => println!("✓ Successfully saved results to {}", eval_file.display()),
309 Err(e) => println!("✗ Failed to write results file: {}", e),
310 }
311
312 Ok(())
313}
314
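/// Runs one exercise end to end: reads its instructions, copies it into a
/// temporary repository at `base_sha`, runs the model against it, and returns
/// an [`EvalResult`] with the diff and token statistics.
///
/// A minimal usage sketch (hypothetical paths and handles, not compiled here):
///
/// ```ignore
/// let result = run_exercise_eval(
///     PathBuf::from("exercises/practice/hello-world"), // hypothetical exercise path
///     model.clone(),
///     app_state.clone(),
///     "abc1234".to_string(), // base SHA used to seed the temp repo
///     PathBuf::from("framework"), // currently unused
///     cx.clone(),
/// )
/// .await?;
/// save_eval_results(&exercise_path, vec![result]).await?;
/// ```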
pub async fn run_exercise_eval(
    exercise_path: PathBuf,
    model: Arc<dyn LanguageModel>,
    app_state: Arc<HeadlessAppState>,
    base_sha: String,
    _framework_path: PathBuf,
    cx: gpui::AsyncApp,
) -> Result<EvalResult> {
    let exercise_name = get_exercise_name(&exercise_path);
    let language = get_exercise_language(&exercise_path)?;
    let mut instructions = read_instructions(&exercise_path).await?;
    instructions.push_str(&format!(
        "\n\nWhen writing the code for this prompt, use {} to achieve the goal.",
        language
    ));

    println!("Running evaluation for exercise: {}", exercise_name);

    // Create temporary directory with exercise files
    let temp_dir = setup_temp_repo(&exercise_path, &base_sha).await?;
    let temp_path = temp_dir.path().to_path_buf();

    let local_commit_sha = run_git(&temp_path, &["rev-parse", "HEAD"]).await?;

    let start_time = SystemTime::now();

    // Create a basic eval struct to work with the existing system
    let eval = Eval {
        repo_path: temp_path.clone(),
        eval_setup: EvalSetup {
            url: format!("file://{}", temp_path.display()),
            // Use the local commit SHA instead of the framework base SHA
            base_sha: local_commit_sha,
        },
        user_prompt: instructions.clone(),
    };

    // Run the evaluation
    let eval_output = cx
        .update(|cx| eval.run(app_state.clone(), model.clone(), cx))?
        .await?;

    // The diff was already computed during the eval run
    let diff = eval_output.diff.clone();

    let elapsed_time = start_time.elapsed()?;

    // Total tokens are the sum of input and output tokens
    let input_tokens = eval_output.token_usage.input_tokens;
    let output_tokens = eval_output.token_usage.output_tokens;
    let tool_use_counts = eval_output.tool_use_counts.values().sum::<u32>();
    let total_tokens = input_tokens + output_tokens;

    // Save results to evaluation directory
    let result = EvalResult {
        exercise_name: exercise_name.clone(),
        diff,
        assistant_response: eval_output.last_message.clone(),
        elapsed_time_ms: elapsed_time.as_millis(),
        timestamp: SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)?
            .as_millis(),
        // Convert u32 token counts to usize
        input_tokens: input_tokens.try_into().unwrap(),
        output_tokens: output_tokens.try_into().unwrap(),
        total_tokens: total_tokens.try_into().unwrap(),
        tool_use_counts: tool_use_counts.try_into().unwrap(),
    };

    Ok(result)
}