Show total number of characters accepted/rejected

Oleksiy Syvokon created

Change summary

crates/edit_prediction_cli/src/example.rs |  6 +
crates/edit_prediction_cli/src/score.rs   | 91 ++++++++++++++++++++++--
2 files changed, 88 insertions(+), 9 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/example.rs 🔗

@@ -188,6 +188,12 @@ pub struct ExampleScore {
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub recall_rate: Option<f64>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub kept_chars: Option<usize>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub correctly_deleted_chars: Option<usize>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub discarded_chars: Option<usize>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
     pub cumulative_logprob: Option<f64>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub avg_logprob: Option<f64>,

crates/edit_prediction_cli/src/score.rs 🔗

@@ -86,6 +86,9 @@ pub async fn run_scoring(
         deleted_tokens: 0,
         kept_rate: None,
         recall_rate: None,
+        kept_chars: None,
+        correctly_deleted_chars: None,
+        discarded_chars: None,
         cumulative_logprob: None,
         avg_logprob: None,
     };
@@ -188,13 +191,20 @@ pub async fn run_scoring(
             prediction.actual_cursor.as_ref(),
         );
 
-        let (kept_rate, recall_rate) = best_expected_text
-            .map(|reference_text| {
-                let result =
-                    metrics::compute_kept_rate(original_text, &actual_text, reference_text);
-                (Some(result.kept_rate), Some(result.recall_rate))
-            })
-            .unwrap_or((None, None));
+        let (kept_rate, recall_rate, kept_chars, correctly_deleted_chars, discarded_chars) =
+            best_expected_text
+                .map(|reference_text| {
+                    let result =
+                        metrics::compute_kept_rate(original_text, &actual_text, reference_text);
+                    (
+                        Some(result.kept_rate),
+                        Some(result.recall_rate),
+                        Some(result.kept_chars),
+                        Some(result.correctly_deleted_chars),
+                        Some(result.discarded_chars),
+                    )
+                })
+                .unwrap_or((None, None, None, None, None));
 
         scores.push(ExampleScore {
             delta_chr_f: best_delta_chr_f_metrics.score as f32,
@@ -217,6 +227,9 @@ pub async fn run_scoring(
             deleted_tokens: token_changes.deleted_tokens,
             kept_rate,
             recall_rate,
+            kept_chars,
+            correctly_deleted_chars,
+            discarded_chars,
             cumulative_logprob: prediction.cumulative_logprob,
             avg_logprob: prediction.avg_logprob,
         });
@@ -283,6 +296,9 @@ pub fn print_report(examples: &[Example], verbose: bool) {
     let mut isolated_whitespace_count: usize = 0;
     let mut kept_rate_sum: f64 = 0.0;
     let mut kept_rate_count: usize = 0;
+    let mut kept_chars_total: usize = 0;
+    let mut correctly_deleted_chars_total: usize = 0;
+    let mut discarded_chars_total: usize = 0;
     let mut recall_rate_sum: f64 = 0.0;
     let mut recall_rate_count: usize = 0;
     let mut patch_inserted_tokens: Vec<usize> = Vec::new();
@@ -382,6 +398,15 @@ pub fn print_report(examples: &[Example], verbose: bool) {
                 kept_rate_sum += kr;
                 kept_rate_count += 1;
             }
+            if let Some(kept_chars) = score.kept_chars {
+                kept_chars_total += kept_chars;
+            }
+            if let Some(correctly_deleted_chars) = score.correctly_deleted_chars {
+                correctly_deleted_chars_total += correctly_deleted_chars;
+            }
+            if let Some(discarded_chars) = score.discarded_chars {
+                discarded_chars_total += discarded_chars;
+            }
             if let Some(rr) = score.recall_rate {
                 recall_rate_sum += rr;
                 recall_rate_count += 1;
@@ -520,9 +545,12 @@ pub fn print_report(examples: &[Example], verbose: bool) {
         if kept_rate_count > 0 {
             let avg_kept_rate = kept_rate_sum / kept_rate_count as f64;
             println!(
-                "Kept rate: {:.1}% avg ({} evaluated)",
+                "Kept rate: {:.1}% avg ({} evaluated, kept chars: {}, correctly deleted chars: {}, discarded chars: {})",
                 avg_kept_rate * 100.0,
-                kept_rate_count
+                kept_rate_count,
+                kept_chars_total,
+                correctly_deleted_chars_total,
+                discarded_chars_total
             );
         }
         if recall_rate_count > 0 {
@@ -640,6 +668,12 @@ pub struct SummaryJson {
     pub avg_kept_rate: Option<f64>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub avg_recall_rate: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub total_kept_chars: Option<usize>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub total_correctly_deleted_chars: Option<usize>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub total_discarded_chars: Option<usize>,
 }
 
 pub fn compute_summary(examples: &[Example]) -> SummaryJson {
@@ -667,6 +701,12 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
     let mut isolated_whitespace_count: usize = 0;
     let mut kept_rate_sum: f64 = 0.0;
     let mut kept_rate_count: usize = 0;
+    let mut kept_chars_total: usize = 0;
+    let mut kept_chars_count: usize = 0;
+    let mut correctly_deleted_chars_total: usize = 0;
+    let mut correctly_deleted_chars_count: usize = 0;
+    let mut discarded_chars_total: usize = 0;
+    let mut discarded_chars_count: usize = 0;
     let mut recall_rate_sum: f64 = 0.0;
     let mut recall_rate_count: usize = 0;
 
@@ -714,6 +754,18 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
                 kept_rate_sum += kr;
                 kept_rate_count += 1;
             }
+            if let Some(kept_chars) = score.kept_chars {
+                kept_chars_total += kept_chars;
+                kept_chars_count += 1;
+            }
+            if let Some(correctly_deleted_chars) = score.correctly_deleted_chars {
+                correctly_deleted_chars_total += correctly_deleted_chars;
+                correctly_deleted_chars_count += 1;
+            }
+            if let Some(discarded_chars) = score.discarded_chars {
+                discarded_chars_total += discarded_chars;
+                discarded_chars_count += 1;
+            }
             if let Some(rr) = score.recall_rate {
                 recall_rate_sum += rr;
                 recall_rate_count += 1;
@@ -805,6 +857,24 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
         None
     };
 
+    let total_kept_chars = if kept_chars_count > 0 {
+        Some(kept_chars_total)
+    } else {
+        None
+    };
+
+    let total_correctly_deleted_chars = if correctly_deleted_chars_count > 0 {
+        Some(correctly_deleted_chars_total)
+    } else {
+        None
+    };
+
+    let total_discarded_chars = if discarded_chars_count > 0 {
+        Some(discarded_chars_total)
+    } else {
+        None
+    };
+
     SummaryJson {
         total_examples: total_scores,
         avg_delta_chr_f,
@@ -839,6 +909,9 @@ pub fn compute_summary(examples: &[Example]) -> SummaryJson {
         isolated_whitespace_rate,
         avg_kept_rate,
         avg_recall_rate,
+        total_kept_chars,
+        total_correctly_deleted_chars,
+        total_discarded_chars,
     }
 }