score.rs

  1use crate::{
  2    PredictArgs,
  3    example::{Example, ExampleScore},
  4    headless::EpAppState,
  5    metrics,
  6    parse_output::parse_prediction_output,
  7    predict::run_prediction,
  8    progress::{ExampleProgress, Step},
  9};
 10use anyhow::Context as _;
 11use edit_prediction::udiff::apply_diff_to_string;
 12use gpui::AsyncApp;
 13use std::sync::Arc;
 14
 15pub async fn run_scoring(
 16    example: &mut Example,
 17    args: &PredictArgs,
 18    app_state: Arc<EpAppState>,
 19    example_progress: &ExampleProgress,
 20    cx: AsyncApp,
 21) -> anyhow::Result<()> {
 22    run_prediction(example, args, app_state, example_progress, cx).await?;
 23
 24    let progress = example_progress.start(Step::Score);
 25
 26    progress.set_substatus("applying patches");
 27    let original_text = &example
 28        .prompt_inputs
 29        .as_ref()
 30        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
 31        .content;
 32    let expected_texts: Vec<String> = example
 33        .spec
 34        .expected_patches
 35        .iter()
 36        .map(|patch| {
 37            apply_diff_to_string(patch, original_text)
 38                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
 39        })
 40        .collect::<Result<Vec<_>, _>>()?;
 41
 42    let zero_scores = ExampleScore {
 43        delta_chr_f: 0.0,
 44        braces_disbalance: 0,
 45    };
 46
 47    progress.set_substatus("computing metrics");
 48    let mut scores = vec![];
 49    for prediction in &example.predictions {
 50        let actual_patch = prediction.actual_patch.clone().or_else(|| {
 51            parse_prediction_output(example, &prediction.actual_output, prediction.provider).ok()
 52        });
 53
 54        let Some(actual_patch) = actual_patch else {
 55            scores.push(zero_scores.clone());
 56            continue;
 57        };
 58
 59        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
 60            Ok(text) => text,
 61            Err(_) => {
 62                scores.push(zero_scores.clone());
 63                continue;
 64            }
 65        };
 66        let best_delta_chr_f = expected_texts
 67            .iter()
 68            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
 69            .fold(0.0, f32::max);
 70
 71        let disbalance_before = metrics::braces_disbalance(&original_text);
 72        let disbalance_after = metrics::braces_disbalance(&actual_text);
 73        let braces_disbalance = disbalance_after.saturating_sub(disbalance_before);
 74        if braces_disbalance > 0 {
 75            std::fs::write(
 76                "/tmp/unbalanced-count.before",
 77                disbalance_before.to_string(),
 78            )
 79            .ok();
 80            std::fs::write("/tmp/unbalanced-count.after", disbalance_after.to_string()).ok();
 81            std::fs::write("/tmp/unbalanced-text.before", &original_text).ok();
 82            std::fs::write("/tmp/unbalanced-text.after", &actual_text).ok();
 83        }
 84
 85        scores.push(ExampleScore {
 86            delta_chr_f: best_delta_chr_f,
 87            braces_disbalance,
 88        });
 89    }
 90
 91    example.score = scores;
 92    Ok(())
 93}
 94
 95pub fn print_report(examples: &[Example]) {
 96    eprintln!(
 97        "──────────────────────────────────────────────────────────────────────────────────────"
 98    );
 99    eprintln!(
100        "{:<50} {:>14} {:>10}",
101        "Example name", "BracesDisbalance", "DeltaChrF"
102    );
103    eprintln!(
104        "──────────────────────────────────────────────────────────────────────────────────────"
105    );
106
107    let mut all_delta_chr_f_scores = Vec::new();
108    let mut braces_disbalance_sum: usize = 0;
109    let mut total_scores: usize = 0;
110
111    for example in examples {
112        for score in example.score.iter() {
113            eprintln!(
114                "{:<50} {:>14} {:>9.2}",
115                truncate_name(&example.spec.name, 50),
116                score.braces_disbalance,
117                score.delta_chr_f
118            );
119
120            all_delta_chr_f_scores.push(score.delta_chr_f);
121            total_scores += 1;
122            braces_disbalance_sum += score.braces_disbalance;
123        }
124    }
125
126    eprintln!(
127        "──────────────────────────────────────────────────────────────────────────────────────"
128    );
129
130    if !all_delta_chr_f_scores.is_empty() {
131        let avg_delta_chr_f: f32 =
132            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
133        let braces_disbalance_avg: f32 = braces_disbalance_sum as f32 / total_scores as f32;
134        let braces_disbalance_display = format!("{:.2}", braces_disbalance_avg);
135
136        eprintln!(
137            "{:<50} {:>14} {:>9.2}",
138            "AVERAGE", braces_disbalance_display, avg_delta_chr_f
139        );
140        eprintln!(
141            "──────────────────────────────────────────────────────────────────────────────────────"
142        );
143    }
144
145    eprintln!("\n");
146}
147
148fn truncate_name(name: &str, max_len: usize) -> String {
149    if name.len() <= max_len {
150        name.to_string()
151    } else {
152        format!("{}...", &name[..max_len - 3])
153    }
154}