score.rs

  1use crate::{
  2    PredictArgs,
  3    example::{Example, ExampleScore},
  4    headless::EpAppState,
  5    metrics,
  6    parse_output::parse_prediction_output,
  7    predict::run_prediction,
  8    progress::{ExampleProgress, Step},
  9};
 10use anyhow::Context as _;
 11use edit_prediction::udiff::apply_diff_to_string;
 12use gpui::AsyncApp;
 13use std::sync::Arc;
 14
 15pub async fn run_scoring(
 16    example: &mut Example,
 17    args: &PredictArgs,
 18    app_state: Arc<EpAppState>,
 19    example_progress: &ExampleProgress,
 20    cx: AsyncApp,
 21) -> anyhow::Result<()> {
 22    run_prediction(example, args, app_state, example_progress, cx).await?;
 23
 24    let progress = example_progress.start(Step::Score);
 25
 26    progress.set_substatus("applying patches");
 27    let original_text = &example
 28        .prompt_inputs
 29        .as_ref()
 30        .context("prompt_inputs is required for scoring - run prediction first or ensure JSON includes prompt_inputs")?
 31        .content;
 32    let expected_texts: Vec<String> = example
 33        .spec
 34        .expected_patches
 35        .iter()
 36        .map(|patch| {
 37            apply_diff_to_string(patch, original_text)
 38                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
 39        })
 40        .collect::<Result<Vec<_>, _>>()?;
 41
 42    progress.set_substatus("computing metrics");
 43    let mut scores = vec![];
 44    for prediction in &example.predictions {
 45        let actual_patch = match &prediction.actual_patch {
 46            Some(patch) => patch.clone(),
 47            None => {
 48                if prediction.actual_output.is_empty() {
 49                    scores.push(ExampleScore { delta_chr_f: 0.0 });
 50                    continue;
 51                }
 52                match parse_prediction_output(
 53                    example,
 54                    &prediction.actual_output,
 55                    prediction.provider,
 56                ) {
 57                    Ok(patch) => patch,
 58                    Err(_) => {
 59                        scores.push(ExampleScore { delta_chr_f: 0.0 });
 60                        continue;
 61                    }
 62                }
 63            }
 64        };
 65        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
 66            Ok(text) => text,
 67            Err(_) => {
 68                scores.push(ExampleScore { delta_chr_f: 0.0 });
 69                continue;
 70            }
 71        };
 72        let best_delta_chr_f = expected_texts
 73            .iter()
 74            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
 75            .fold(0.0, f32::max);
 76        scores.push(ExampleScore {
 77            delta_chr_f: best_delta_chr_f,
 78        });
 79    }
 80
 81    example.score = scores;
 82    Ok(())
 83}
 84
 85pub fn print_report(examples: &[Example]) {
 86    eprintln!(
 87        "──────────────────────────────────────────────────────────────────────────────────────"
 88    );
 89    eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
 90    eprintln!(
 91        "──────────────────────────────────────────────────────────────────────────────────────"
 92    );
 93
 94    let mut all_delta_chr_f_scores = Vec::new();
 95
 96    for example in examples {
 97        for score in example.score.iter() {
 98            eprintln!(
 99                "{:<50} {:>9.2}",
100                truncate_name(&example.spec.name, 50),
101                score.delta_chr_f
102            );
103
104            all_delta_chr_f_scores.push(score.delta_chr_f);
105        }
106    }
107
108    eprintln!(
109        "──────────────────────────────────────────────────────────────────────────────────────"
110    );
111
112    if !all_delta_chr_f_scores.is_empty() {
113        let avg_delta_chr_f: f32 =
114            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
115
116        eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
117        eprintln!(
118            "──────────────────────────────────────────────────────────────────────────────────────"
119        );
120    }
121
122    eprintln!("\n");
123}
124
/// Shortens `name` to at most `max_len` characters for display in a fixed-width column.
///
/// A name that already fits is returned unchanged. A longer name is cut and suffixed with
/// `"..."`, so the result is exactly `max_len` characters long. If `max_len` is too small
/// to hold the ellipsis, the name is simply cut to `max_len` characters.
///
/// Lengths are counted in `char`s, not bytes. This keeps every cut on a UTF-8 boundary,
/// whereas byte slicing panics on multi-byte names. It also matches how `{:<N}` padding
/// measures width.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let mut truncated: String = name.chars().take(max_len - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}