diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 1e044b0dae353498b67ffa917d89e2945f4f7787..a23010fa21c9593eda058af23498f9ae19577235 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -188,6 +188,12 @@ pub struct ExampleScore { #[serde(default, skip_serializing_if = "Option::is_none")] pub recall_rate: Option, #[serde(default, skip_serializing_if = "Option::is_none")] + pub kept_chars: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub correctly_deleted_chars: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub discarded_chars: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub cumulative_logprob: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub avg_logprob: Option, diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs index f30cf7d106f737f1e479fdac38adc10e4effcea2..38329c8c3329fa3f26f5795b6a9bdcd02997b59f 100644 --- a/crates/edit_prediction_cli/src/score.rs +++ b/crates/edit_prediction_cli/src/score.rs @@ -86,6 +86,9 @@ pub async fn run_scoring( deleted_tokens: 0, kept_rate: None, recall_rate: None, + kept_chars: None, + correctly_deleted_chars: None, + discarded_chars: None, cumulative_logprob: None, avg_logprob: None, }; @@ -188,13 +191,20 @@ pub async fn run_scoring( prediction.actual_cursor.as_ref(), ); - let (kept_rate, recall_rate) = best_expected_text - .map(|reference_text| { - let result = - metrics::compute_kept_rate(original_text, &actual_text, reference_text); - (Some(result.kept_rate), Some(result.recall_rate)) - }) - .unwrap_or((None, None)); + let (kept_rate, recall_rate, kept_chars, correctly_deleted_chars, discarded_chars) = + best_expected_text + .map(|reference_text| { + let result = + metrics::compute_kept_rate(original_text, &actual_text, reference_text); + ( + Some(result.kept_rate), + Some(result.recall_rate), + Some(result.kept_chars), + Some(result.correctly_deleted_chars), + Some(result.discarded_chars), + ) + }) + .unwrap_or((None, None, None, None, None)); scores.push(ExampleScore { delta_chr_f: best_delta_chr_f_metrics.score as f32, @@ -217,6 +227,9 @@ pub async fn run_scoring( deleted_tokens: token_changes.deleted_tokens, kept_rate, recall_rate, + kept_chars, + correctly_deleted_chars, + discarded_chars, cumulative_logprob: prediction.cumulative_logprob, avg_logprob: prediction.avg_logprob, }); @@ -283,6 +296,9 @@ pub fn print_report(examples: &[Example], verbose: bool) { let mut isolated_whitespace_count: usize = 0; let mut kept_rate_sum: f64 = 0.0; let mut kept_rate_count: usize = 0; + let mut kept_chars_total: usize = 0; + let mut correctly_deleted_chars_total: usize = 0; + let mut discarded_chars_total: usize = 0; let mut recall_rate_sum: f64 = 0.0; let mut recall_rate_count: usize = 0; let mut patch_inserted_tokens: Vec = Vec::new(); @@ -382,6 +398,15 @@ pub fn print_report(examples: &[Example], verbose: bool) { kept_rate_sum += kr; kept_rate_count += 1; } + if let Some(kept_chars) = score.kept_chars { + kept_chars_total += kept_chars; + } + if let Some(correctly_deleted_chars) = score.correctly_deleted_chars { + correctly_deleted_chars_total += correctly_deleted_chars; + } + if let Some(discarded_chars) = score.discarded_chars { + discarded_chars_total += discarded_chars; + } if let Some(rr) = score.recall_rate { recall_rate_sum += rr; recall_rate_count += 1; @@ -520,9 +545,12 @@ pub fn print_report(examples: &[Example], verbose: bool) { if kept_rate_count > 0 { let avg_kept_rate = kept_rate_sum / kept_rate_count as f64; println!( - "Kept rate: {:.1}% avg ({} evaluated)", + "Kept rate: {:.1}% avg ({} evaluated, kept chars: {}, correctly deleted chars: {}, discarded chars: {})", avg_kept_rate * 100.0, - kept_rate_count + kept_rate_count, + kept_chars_total, + correctly_deleted_chars_total, + discarded_chars_total ); } if recall_rate_count > 0 { @@ -640,6 +668,12 @@ pub struct SummaryJson { pub avg_kept_rate: Option, #[serde(skip_serializing_if = "Option::is_none")] pub avg_recall_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_kept_chars: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_correctly_deleted_chars: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_discarded_chars: Option, } pub fn compute_summary(examples: &[Example]) -> SummaryJson { @@ -667,6 +701,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { let mut isolated_whitespace_count: usize = 0; let mut kept_rate_sum: f64 = 0.0; let mut kept_rate_count: usize = 0; + let mut kept_chars_total: usize = 0; + let mut kept_chars_count: usize = 0; + let mut correctly_deleted_chars_total: usize = 0; + let mut correctly_deleted_chars_count: usize = 0; + let mut discarded_chars_total: usize = 0; + let mut discarded_chars_count: usize = 0; let mut recall_rate_sum: f64 = 0.0; let mut recall_rate_count: usize = 0; @@ -714,6 +754,18 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { kept_rate_sum += kr; kept_rate_count += 1; } + if let Some(kept_chars) = score.kept_chars { + kept_chars_total += kept_chars; + kept_chars_count += 1; + } + if let Some(correctly_deleted_chars) = score.correctly_deleted_chars { + correctly_deleted_chars_total += correctly_deleted_chars; + correctly_deleted_chars_count += 1; + } + if let Some(discarded_chars) = score.discarded_chars { + discarded_chars_total += discarded_chars; + discarded_chars_count += 1; + } if let Some(rr) = score.recall_rate { recall_rate_sum += rr; recall_rate_count += 1; @@ -805,6 +857,24 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { None }; + let total_kept_chars = if kept_chars_count > 0 { + Some(kept_chars_total) + } else { + None + }; + + let total_correctly_deleted_chars = if correctly_deleted_chars_count > 0 { + Some(correctly_deleted_chars_total) + } else { + None + }; + + let total_discarded_chars = if discarded_chars_count > 0 { + Some(discarded_chars_total) + } else { + None + }; + SummaryJson { total_examples: total_scores, avg_delta_chr_f, @@ -839,6 +909,9 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson { isolated_whitespace_rate, avg_kept_rate, avg_recall_rate, + total_kept_chars, + total_correctly_deleted_chars, + total_discarded_chars, } }