score.rs

  1use crate::{
  2    PredictArgs,
  3    example::{Example, ExampleScore},
  4    headless::EpAppState,
  5    metrics,
  6    predict::run_prediction,
  7    progress::{Progress, Step},
  8};
  9use anyhow::Context as _;
 10use edit_prediction::udiff::apply_diff_to_string;
 11use gpui::AsyncApp;
 12use std::sync::Arc;
 13
 14pub async fn run_scoring(
 15    example: &mut Example,
 16    args: &PredictArgs,
 17    app_state: Arc<EpAppState>,
 18    cx: AsyncApp,
 19) -> anyhow::Result<()> {
 20    run_prediction(
 21        example,
 22        Some(args.provider),
 23        args.repetitions,
 24        app_state,
 25        cx,
 26    )
 27    .await?;
 28
 29    let _progress = Progress::global().start(Step::Score, &example.spec.name);
 30
 31    let original_text = &example.buffer.as_ref().unwrap().content;
 32    let expected_texts: Vec<String> = example
 33        .spec
 34        .expected_patches
 35        .iter()
 36        .map(|patch| {
 37            apply_diff_to_string(original_text, patch)
 38                .with_context(|| format!("Expected patch did not apply for {}", example.spec.name))
 39        })
 40        .collect::<Result<Vec<_>, _>>()?;
 41
 42    let mut scores = vec![];
 43    for prediction in &example.predictions {
 44        let actual_text = match apply_diff_to_string(original_text, &prediction.actual_patch) {
 45            Ok(text) => text,
 46            Err(_) => {
 47                scores.push(ExampleScore { delta_chr_f: 0.0 });
 48                continue;
 49            }
 50        };
 51        let best_delta_chr_f = expected_texts
 52            .iter()
 53            .map(|expected| metrics::delta_chr_f(original_text, expected, &actual_text) as f32)
 54            .fold(0.0, f32::max);
 55        scores.push(ExampleScore {
 56            delta_chr_f: best_delta_chr_f,
 57        });
 58    }
 59
 60    example.score = scores;
 61    Ok(())
 62}
 63
 64pub fn print_report(examples: &[Example]) {
 65    eprintln!(
 66        "──────────────────────────────────────────────────────────────────────────────────────"
 67    );
 68    eprintln!("{:<50} {:>10}", "Example name", "DeltaChrF");
 69    eprintln!(
 70        "──────────────────────────────────────────────────────────────────────────────────────"
 71    );
 72
 73    let mut all_delta_chr_f_scores = Vec::new();
 74
 75    for example in examples {
 76        for score in example.score.iter() {
 77            eprintln!(
 78                "{:<50} {:>9.2}",
 79                truncate_name(&example.spec.name, 50),
 80                score.delta_chr_f
 81            );
 82
 83            all_delta_chr_f_scores.push(score.delta_chr_f);
 84        }
 85    }
 86
 87    eprintln!(
 88        "──────────────────────────────────────────────────────────────────────────────────────"
 89    );
 90
 91    if !all_delta_chr_f_scores.is_empty() {
 92        let avg_delta_chr_f: f32 =
 93            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
 94
 95        eprintln!("{:<50} {:>9.2}", "AVERAGE", avg_delta_chr_f);
 96        eprintln!(
 97            "──────────────────────────────────────────────────────────────────────────────────────"
 98        );
 99    }
100
101    eprintln!("\n");
102}
103
/// Shortens `name` to at most `max_len` characters, replacing the cut-off tail with `...`.
///
/// Lengths are measured in `char`s rather than bytes. Byte-offset slicing would panic when
/// the cut lands inside a multi-byte UTF-8 character, and `{:<N}` padding in format strings
/// counts characters as well, so this keeps table columns aligned.
///
/// Names that already fit are returned unchanged. When `max_len` is smaller than the
/// ellipsis itself, the ellipsis is shortened so the result never exceeds `max_len`.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // Reserve room for the ellipsis; saturating so a tiny `max_len` cannot underflow.
    let keep = max_len.saturating_sub(ELLIPSIS.len());
    let mut truncated: String = name.chars().take(keep).collect();
    // `max_len - keep` is 3 in the usual case and less only when `max_len < 3`.
    truncated.extend(ELLIPSIS.chars().take(max_len - keep));
    truncated
}