diff --git a/crates/edit_prediction_metrics/src/prediction_score.rs b/crates/edit_prediction_metrics/src/prediction_score.rs index 55c1d828762dd0d1fd9933ceb58445c758d86604..942ce3c9d1a47720486dc2f629cda8c8f5e077fb 100644 --- a/crates/edit_prediction_metrics/src/prediction_score.rs +++ b/crates/edit_prediction_metrics/src/prediction_score.rs @@ -218,7 +218,9 @@ pub fn score_prediction(input: PredictionScoringInput<'_>) -> PredictionScore { for expected in input.expected_patches { let delta_chr_f_metrics = delta_chr_f(input.original_text, &expected.text, &actual_text); - if delta_chr_f_metrics.score > best_delta_chr_f_metrics.score { + if best_expected_text.is_none() + || delta_chr_f_metrics.score > best_delta_chr_f_metrics.score + { best_delta_chr_f_metrics = delta_chr_f_metrics; best_expected_cursor = expected.cursor_editable_region_offset; best_expected_text = Some(expected.text.as_str()); @@ -317,3 +319,33 @@ fn compute_cursor_metrics( (Some(_), None) | (None, Some(_)) => (None, Some(false)), } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_kept_rate_is_computed_when_best_delta_chr_f_score_is_zero() { + let original_text = ""; + let actual_patch = "--- a/file.txt\n+++ b/file.txt\n@@ -0,0 +1 @@\n+bbbbbb\n"; + let expected_patch = "--- a/file.txt\n+++ b/file.txt\n@@ -0,0 +1 @@\n+cccccc\n"; + let expected_patches = [PreparedExpectedPatch { + patch: expected_patch.to_string(), + text: "cccccc".to_string(), + cursor_editable_region_offset: None, + }]; + + let score = score_prediction(PredictionScoringInput { + original_text, + expected_patches: &expected_patches, + actual_patch: Some(actual_patch), + actual_cursor: None, + reversal_context: None, + cumulative_logprob: None, + avg_logprob: None, + }); + + assert_eq!(score.delta_chr_f, 0.0); + assert_eq!(score.kept_rate, Some(0.0)); + } +}