ep: Change kept_rate definition to a more intuitive one (#54306)

Oleksiy Syvokon created

This change contains a number of fixes to make kept_rate more intuitive.
It also adds a CLI utility to print debug info on how the metric is
computed.

Release Notes:

- N/A

Change summary

Cargo.lock                                                    |   1 
crates/edit_prediction_metrics/Cargo.toml                     |   1 
crates/edit_prediction_metrics/src/edit_prediction_metrics.rs |   4 
crates/edit_prediction_metrics/src/kept_rate.rs               | 283 +
crates/edit_prediction_metrics/src/main.rs                    | 710 +++++
crates/edit_prediction_metrics/src/patch_metrics.rs           |  29 
crates/edit_prediction_metrics/src/tokenize.rs                | 206 +
typos.toml                                                    |   1 
8 files changed, 1,145 insertions(+), 90 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5338,6 +5338,7 @@ dependencies = [
  "language",
  "pretty_assertions",
  "serde",
+ "serde_json",
  "similar",
  "tree-sitter",
  "zeta_prompt",

crates/edit_prediction_metrics/Cargo.toml 🔗

@@ -14,6 +14,7 @@ path = "src/edit_prediction_metrics.rs"
 [dependencies]
 language.workspace = true
 serde.workspace = true
+serde_json = "1.0"
 similar = "2.7.0"
 tree-sitter.workspace = true
 zeta_prompt.workspace = true

crates/edit_prediction_metrics/src/edit_prediction_metrics.rs 🔗

@@ -4,9 +4,10 @@ mod reversal;
 mod tokenize;
 mod tree_sitter;
 
+pub use kept_rate::AnnotatedToken;
 pub use kept_rate::KeptRateResult;
-#[cfg(test)]
 pub use kept_rate::TokenAnnotation;
+pub use kept_rate::annotate_kept_rate_tokens;
 pub use kept_rate::compute_kept_rate;
 pub use patch_metrics::ClassificationMetrics;
 pub use patch_metrics::Counts;
@@ -20,5 +21,6 @@ pub use patch_metrics::exact_lines_match;
 pub use patch_metrics::extract_changed_lines_from_diff;
 pub use patch_metrics::has_isolated_whitespace_changes;
 pub use patch_metrics::is_editable_region_correct;
+pub use patch_metrics::reconstruct_texts_from_diff;
 pub use reversal::compute_prediction_reversal_ratio_from_history;
 pub use tree_sitter::count_tree_sitter_errors;

crates/edit_prediction_metrics/src/kept_rate.rs 🔗

@@ -3,14 +3,20 @@ use serde::Serialize;
 
 const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
 
-#[cfg(test)]
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
+#[serde(rename_all = "snake_case")]
 pub enum TokenAnnotation {
     Context,
     Kept,
     Discarded,
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
+pub struct AnnotatedToken {
+    pub token: String,
+    pub annotation: TokenAnnotation,
+}
+
 #[allow(dead_code)]
 #[derive(Debug, Clone, Serialize)]
 pub struct KeptRateResult {
@@ -40,8 +46,7 @@ pub struct KeptRateResult {
     /// This includes both kept newly introduced characters and correctly
     /// deleted base characters.
     pub recall_rate: f64,
-    /// Per-token classification for candidate tokens used by tests.
-    #[cfg(test)]
+    /// Per-token classification for candidate tokens.
     pub token_annotations: Vec<TokenAnnotation>,
 }
 
@@ -51,9 +56,9 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize {
 
 /// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
 /// while sharing a single DP table construction.
-fn fill_lcs_keep_masks(
-    a: &[&str],
-    b: &[&str],
+fn fill_lcs_keep_masks<T: Eq>(
+    a: &[T],
+    b: &[T],
     mut keep_a: Option<&mut [bool]>,
     mut keep_b: Option<&mut [bool]>,
 ) {
@@ -124,10 +129,10 @@ fn fill_lcs_keep_masks(
     let mut dp = vec![0u32; row_count * column_count];
 
     for i in 1..row_count {
-        let token_a = a_mid[i - 1];
+        let token_a = &a_mid[i - 1];
         for j in 1..column_count {
             let index = dp_index(column_count, i, j);
-            if token_a == b_mid[j - 1] {
+            if token_a == &b_mid[j - 1] {
                 dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
             } else {
                 let up = dp[dp_index(column_count, i - 1, j)];
@@ -180,41 +185,91 @@ fn fill_lcs_keep_masks(
     }
 }
 
-fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
+fn lcs_keep_mask<T: Eq>(a: &[T], b: &[T]) -> Vec<bool> {
     let mut keep_a = vec![false; a.len()];
     fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
     keep_a
 }
 
-fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
+fn lcs_keep_masks<T: Eq>(a: &[T], b: &[T]) -> (Vec<bool>, Vec<bool>) {
     let mut keep_a = vec![false; a.len()];
     let mut keep_b = vec![false; b.len()];
     fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
     (keep_a, keep_b)
 }
 
-fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
-    let mut unmasked_tokens = Vec::with_capacity(tokens.len());
+#[derive(Debug, Clone)]
+struct ComparisonUnit {
+    text: String,
+    token_start: usize,
+    token_end: usize,
+}
+
+fn is_identifier_token(token: &str) -> bool {
+    !token.is_empty()
+        && token
+            .chars()
+            .all(|character| character.is_alphanumeric() || character == '_')
+}
+
+fn build_comparison_units(tokens: &[&str]) -> Vec<ComparisonUnit> {
+    let mut units = Vec::new();
+    let mut index = 0;
+
+    while index < tokens.len() {
+        let token_start = index;
+
+        if is_identifier_token(tokens[index]) {
+            let mut text = String::new();
+
+            while index < tokens.len() && is_identifier_token(tokens[index]) {
+                text.push_str(tokens[index]);
+                index += 1;
+            }
+
+            units.push(ComparisonUnit {
+                text,
+                token_start,
+                token_end: index,
+            });
+        } else {
+            units.push(ComparisonUnit {
+                text: tokens[index].to_string(),
+                token_start,
+                token_end: index + 1,
+            });
+            index += 1;
+        }
+    }
+
+    units
+}
+
+fn analyze_masked_units<'a>(
+    units: &'a [ComparisonUnit],
+    mask: &[bool],
+) -> (Vec<&'a str>, usize, usize) {
+    let mut unmasked_units = Vec::with_capacity(units.len());
     let mut unmasked_chars = 0;
     let mut masked_chars = 0;
 
-    for (&token, &is_masked) in tokens.iter().zip(mask.iter()) {
+    for (unit, &is_masked) in units.iter().zip(mask.iter()) {
         if is_masked {
-            masked_chars += token.len();
+            masked_chars += unit.text.len();
         } else {
-            unmasked_tokens.push(token);
-            unmasked_chars += token.len();
+            unmasked_units.push(unit.text.as_str());
+            unmasked_chars += unit.text.len();
         }
     }
 
-    (unmasked_tokens, unmasked_chars, masked_chars)
+    (unmasked_units, unmasked_chars, masked_chars)
 }
 
-fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize {
-    tokens
+fn count_unmasked_unit_chars(units: &[ComparisonUnit], mask: &[bool]) -> usize {
+    units
         .iter()
         .zip(mask.iter())
-        .filter_map(|(&token, &is_masked)| (!is_masked).then_some(token.len()))
+        .filter_map(|(unit, &is_masked)| (!is_masked).then_some(unit.text.len()))
         .sum()
 }
 
@@ -239,7 +294,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
             context_chars,
             kept_rate: 1.0,
             recall_rate: 1.0,
-            #[cfg(test)]
             token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
         };
     }
@@ -258,7 +312,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
             context_chars: 0,
             kept_rate: 0.0,
             recall_rate: 0.0,
-            #[cfg(test)]
             token_annotations: vec![TokenAnnotation::Discarded; tokenize(candidate).len()],
         };
     }
@@ -267,29 +320,29 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
     let candidate_tokens = tokenize(candidate);
     let reference_tokens = tokenize(reference);
 
-    let (candidate_base_mask, base_candidate_mask) =
-        lcs_keep_masks(&candidate_tokens, &base_tokens);
-    let (candidate_reference_mask, reference_candidate_mask) =
-        lcs_keep_masks(&candidate_tokens, &reference_tokens);
-    let context_mask: Vec<bool> = candidate_base_mask
+    let candidate_units = build_comparison_units(&candidate_tokens);
+    let base_units = build_comparison_units(&base_tokens);
+    let reference_units = build_comparison_units(&reference_tokens);
+
+    let candidate_unit_texts: Vec<&str> = candidate_units
         .iter()
-        .zip(candidate_reference_mask.iter())
-        .map(|(&in_base, &in_reference)| in_base && in_reference)
+        .map(|unit| unit.text.as_str())
+        .collect();
+    let base_unit_texts: Vec<&str> = base_units.iter().map(|unit| unit.text.as_str()).collect();
+    let reference_unit_texts: Vec<&str> = reference_units
+        .iter()
+        .map(|unit| unit.text.as_str())
         .collect();
 
+    let (candidate_base_mask, base_candidate_mask) =
+        lcs_keep_masks(&candidate_unit_texts, &base_unit_texts);
     let (stripped_candidate, candidate_new_chars, context_chars) =
-        analyze_masked_tokens(&candidate_tokens, &context_mask);
+        analyze_masked_units(&candidate_units, &candidate_base_mask);
 
     let (reference_base_mask, base_reference_mask) =
-        lcs_keep_masks(&reference_tokens, &base_tokens);
-    let reference_context_mask: Vec<bool> = reference_base_mask
-        .iter()
-        .zip(reference_candidate_mask.iter())
-        .map(|(&in_base, &in_candidate)| in_base && in_candidate)
-        .collect();
-
+        lcs_keep_masks(&reference_unit_texts, &base_unit_texts);
     let (stripped_reference, reference_new_chars, _) =
-        analyze_masked_tokens(&reference_tokens, &reference_context_mask);
+        analyze_masked_units(&reference_units, &reference_base_mask);
 
     let keep_mask = lcs_keep_mask(&stripped_candidate, &stripped_reference);
 
@@ -299,13 +352,13 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
         .filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
         .sum();
 
-    let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask);
-    let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask);
-    let correctly_deleted_chars: usize = base_tokens
+    let candidate_deleted_chars = count_unmasked_unit_chars(&base_units, &base_candidate_mask);
+    let reference_deleted_chars = count_unmasked_unit_chars(&base_units, &base_reference_mask);
+    let correctly_deleted_chars: usize = base_units
         .iter()
         .zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
-        .filter_map(|(&token, (&in_candidate, &in_reference))| {
-            (!in_candidate && !in_reference).then_some(token.len())
+        .filter_map(|(unit, (&in_candidate, &in_reference))| {
+            (!in_candidate && !in_reference).then_some(unit.text.len())
         })
         .sum();
 
@@ -326,24 +379,28 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
         matched_edit_chars as f64 / reference_edit_chars as f64
     };
 
-    #[cfg(test)]
     let token_annotations = {
-        let mut token_annotations = Vec::with_capacity(candidate_tokens.len());
+        let mut token_annotations = vec![TokenAnnotation::Context; candidate_tokens.len()];
         let mut new_index = 0;
-        for (token_index, _token) in candidate_tokens.iter().enumerate() {
-            if context_mask[token_index] {
-                token_annotations.push(TokenAnnotation::Context);
+
+        for (unit_index, unit) in candidate_units.iter().enumerate() {
+            let annotation = if candidate_base_mask[unit_index] {
+                TokenAnnotation::Context
             } else {
                 let annotation = if keep_mask[new_index] {
                     TokenAnnotation::Kept
                 } else {
                     TokenAnnotation::Discarded
                 };
-                #[cfg(test)]
-                token_annotations.push(annotation);
                 new_index += 1;
+                annotation
+            };
+
+            for token_index in unit.token_start..unit.token_end {
+                token_annotations[token_index] = annotation;
             }
         }
+
         token_annotations
     };
 
@@ -358,14 +415,30 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
         context_chars,
         kept_rate,
         recall_rate,
-        #[cfg(test)]
         token_annotations,
     }
 }
 
+pub fn annotate_kept_rate_tokens(
+    base: &str,
+    candidate: &str,
+    reference: &str,
+) -> Vec<AnnotatedToken> {
+    let result = compute_kept_rate(base, candidate, reference);
+    tokenize(candidate)
+        .into_iter()
+        .zip(result.token_annotations)
+        .map(|(token, annotation)| AnnotatedToken {
+            token: token.to_string(),
+            annotation,
+        })
+        .collect()
+}
+
 #[cfg(test)]
 mod test_kept_rate {
     use super::*;
+    use indoc::indoc;
 
     #[test]
     fn test_lcs_keep_masks() {
@@ -439,16 +512,24 @@ mod test_kept_rate {
 
     #[test]
     fn test_missing_deletion() {
-        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
-        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\neprintln!(\"\");\n";
-        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
+        let base = indoc! {"
+            fn example() {
+                epr
+        "};
+        let candidate = indoc! {r#"
+            fn example() {
+                epr
+            eprintln!("");
+        "#};
+        let reference = indoc! {r#"
+            fn example() {
+            eprintln!("");
+        "#};
+
         let result = compute_kept_rate(base, candidate, reference);
-        assert!(
-            result.kept_rate < 0.85,
-            "expected kept_rate < 0.85, got {}",
-            result.kept_rate
-        );
-        assert!(result.discarded_chars > 0);
+        assert!((result.kept_rate - (14.0 / 15.0)).abs() < 1e-6);
+        assert_eq!(result.kept_chars, 14);
+        assert_eq!(result.discarded_chars, 1);
     }
 
     #[test]
@@ -472,8 +553,17 @@ mod test_kept_rate {
 
     #[test]
     fn test_bails_for_dirty_final() {
-        let base = "fn example() {\n    work();\n}\n";
-        let candidate = "fn example() {\n    work();\n    predicted();\n}\n";
+        let base = indoc! {"
+            fn example() {
+                work();
+            }
+        "};
+        let candidate = indoc! {"
+            fn example() {
+                work();
+                predicted();
+            }
+        "};
         let reference = format!(
             "fn example() {{\n    work();\n    {}\n}}\n",
             "settled();\n    ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
@@ -488,9 +578,19 @@ mod test_kept_rate {
 
     #[test]
     fn test_eprintln_token_alignment() {
-        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
-        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
-        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
+        let base = indoc! {"
+            fn example() {
+                epr
+        "};
+        let candidate = indoc! {r#"
+            fn example() {
+                eprintln!("hello world!");
+        "#};
+        let reference = indoc! {r#"
+            fn example() {
+                eprintln!("");
+        "#};
+
         let result = compute_kept_rate(base, candidate, reference);
         assert!(result.discarded_chars > 0);
         assert!(result.kept_chars > 0);
@@ -499,6 +599,42 @@ mod test_kept_rate {
         assert_eq!(result.discarded_chars, 12);
     }
 
+    #[test]
+    fn test_kept_rate_treats_unchanged_stale_text_as_context() {
+        let base = indoc! {"
+            a=fomr
+            b=old
+        "};
+        let candidate = indoc! {"
+            a=formula;
+            b=old
+        "};
+        let reference = indoc! {"
+            a=formula;
+            b=new
+        "};
+
+        let result = compute_kept_rate(base, candidate, reference);
+        let candidate_tokens = tokenize(candidate);
+
+        assert_eq!(result.candidate_new_chars, "formula".len() + ";".len());
+        assert_eq!(result.kept_chars, "formula".len() + ";".len());
+        assert_eq!(result.discarded_chars, 0);
+        assert_eq!(result.candidate_deleted_chars, "fomr".len());
+        assert_eq!(result.correctly_deleted_chars, "fomr".len());
+        assert!((result.kept_rate - 1.0).abs() < 1e-6);
+        assert!((result.recall_rate - (2.0 / 3.0)).abs() < 1e-6);
+
+        let old_index = candidate_tokens
+            .iter()
+            .position(|&token| token == "old")
+            .expect("old token not found");
+        assert_eq!(
+            result.token_annotations[old_index],
+            TokenAnnotation::Context
+        );
+    }
+
     #[test]
     fn test_annotations_rename() {
         let base = "    foo(old_name)\n";
@@ -514,7 +650,7 @@ mod test_kept_rate {
         assert_eq!(result.token_annotations.len(), tokenize(candidate).len());
 
         for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
-            if token == "new_name" {
+            if matches!(token, "new" | "_" | "name") {
                 assert_eq!(annotation, TokenAnnotation::Kept);
             } else {
                 assert_eq!(annotation, TokenAnnotation::Context);
@@ -524,9 +660,18 @@ mod test_kept_rate {
 
     #[test]
     fn test_annotations_eprintln_coloring() {
-        let base = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        epr\n";
-        let candidate = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"hello world!\");\n";
-        let reference = "    fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n        eprintln!(\"\");\n";
+        let base = indoc! {"
+            fn example() {
+                epr
+        "};
+        let candidate = indoc! {r#"
+            fn example() {
+                eprintln!("hello world!");
+        "#};
+        let reference = indoc! {r#"
+            fn example() {
+                eprintln!("");
+        "#};
         let result = compute_kept_rate(base, candidate, reference);
         let candidate_tokens = tokenize(candidate);
 

crates/edit_prediction_metrics/src/main.rs 🔗

@@ -0,0 +1,710 @@
+use std::env;
+use std::fmt::Write as _;
+use std::fs;
+use std::path::Path;
+use std::process;
+
+use edit_prediction_metrics::{
+    ClassificationMetrics, DeltaChrFMetrics, KeptRateResult, TokenAnnotation,
+    annotate_kept_rate_tokens, braces_disbalance, compute_kept_rate, count_patch_token_changes,
+    delta_chr_f, exact_lines_match, extract_changed_lines_from_diff,
+    has_isolated_whitespace_changes, is_editable_region_correct,
+};
+use serde::Deserialize;
+
+fn main() {
+    if let Err(error) = run() {
+        eprintln!("error: {error}");
+        process::exit(1);
+    }
+}
+
+fn run() -> Result<(), String> {
+    let args: Vec<String> = env::args().skip(1).collect();
+    if args.is_empty() {
+        print_usage();
+        return Err("missing arguments".to_string());
+    }
+
+    let input = CliInput::parse(&args)?;
+    let report = match input {
+        CliInput::Files {
+            base_path,
+            expected_patch_path,
+            actual_patch_path,
+        } => {
+            let base = fs::read_to_string(&base_path)
+                .map_err(|err| format!("failed to read {}: {err}", base_path.display()))?;
+            let expected_patch = fs::read_to_string(&expected_patch_path).map_err(|err| {
+                format!("failed to read {}: {err}", expected_patch_path.display())
+            })?;
+            let actual_patch = fs::read_to_string(&actual_patch_path)
+                .map_err(|err| format!("failed to read {}: {err}", actual_patch_path.display()))?;
+
+            let expected = apply_patch_to_excerpt(&base, &expected_patch, 0)?;
+            let actual = apply_patch_to_excerpt(&base, &actual_patch, 0)?;
+
+            EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
+        }
+        CliInput::Json {
+            json_path,
+            prediction_index,
+        } => {
+            let json = fs::read_to_string(&json_path)
+                .map_err(|err| format!("failed to read {}: {err}", json_path.display()))?;
+            let example: JsonExample = serde_json::from_str(&json)
+                .map_err(|err| format!("failed to parse {}: {err}", json_path.display()))?;
+
+            let base = example.prompt_inputs.cursor_excerpt;
+            let excerpt_start_row = example.prompt_inputs.excerpt_start_row;
+            let expected_patch = example
+                .expected_patches
+                .into_iter()
+                .next()
+                .ok_or_else(|| "JSON input is missing expected_patches[0]".to_string())?;
+            let actual_patch = example
+                .predictions
+                .into_iter()
+                .nth(prediction_index)
+                .ok_or_else(|| {
+                    format!("JSON input does not contain predictions[{prediction_index}]")
+                })?
+                .actual_patch;
+
+            let expected = apply_patch_to_excerpt(&base, &expected_patch, excerpt_start_row)?;
+            let actual = apply_patch_to_excerpt(&base, &actual_patch, excerpt_start_row)?;
+
+            EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
+        }
+    };
+
+    print_report(&report);
+    Ok(())
+}
+
+fn print_usage() {
+    eprintln!(
+        "Usage:\n  edit_prediction_metrics --base <base.txt> --expected-patch <expected.diff> --actual-patch <actual.diff>\n  edit_prediction_metrics --json <example.json> [--prediction-index <n>]"
+    );
+}
+
+enum CliInput {
+    Files {
+        base_path: std::path::PathBuf,
+        expected_patch_path: std::path::PathBuf,
+        actual_patch_path: std::path::PathBuf,
+    },
+    Json {
+        json_path: std::path::PathBuf,
+        prediction_index: usize,
+    },
+}
+
+impl CliInput {
+    fn parse(args: &[String]) -> Result<Self, String> {
+        let mut base_path = None;
+        let mut expected_patch_path = None;
+        let mut actual_patch_path = None;
+        let mut json_path = None;
+        let mut prediction_index = 0usize;
+
+        let mut index = 0;
+        while index < args.len() {
+            match args[index].as_str() {
+                "--base" => {
+                    index += 1;
+                    base_path = Some(path_arg(args, index, "--base")?);
+                }
+                "--expected-patch" => {
+                    index += 1;
+                    expected_patch_path = Some(path_arg(args, index, "--expected-patch")?);
+                }
+                "--actual-patch" => {
+                    index += 1;
+                    actual_patch_path = Some(path_arg(args, index, "--actual-patch")?);
+                }
+                "--json" => {
+                    index += 1;
+                    json_path = Some(path_arg(args, index, "--json")?);
+                }
+                "--prediction-index" => {
+                    index += 1;
+                    let raw = string_arg(args, index, "--prediction-index")?;
+                    prediction_index = raw.parse::<usize>().map_err(|err| {
+                        format!("invalid value for --prediction-index ({raw}): {err}")
+                    })?;
+                }
+                "--help" | "-h" => {
+                    print_usage();
+                    process::exit(0);
+                }
+                unknown => {
+                    return Err(format!("unrecognized argument: {unknown}"));
+                }
+            }
+            index += 1;
+        }
+
+        if let Some(json_path) = json_path {
+            if base_path.is_some() || expected_patch_path.is_some() || actual_patch_path.is_some() {
+                return Err(
+                    "--json cannot be combined with --base/--expected-patch/--actual-patch"
+                        .to_string(),
+                );
+            }
+            return Ok(CliInput::Json {
+                json_path,
+                prediction_index,
+            });
+        }
+
+        match (base_path, expected_patch_path, actual_patch_path) {
+            (Some(base_path), Some(expected_patch_path), Some(actual_patch_path)) => {
+                Ok(CliInput::Files {
+                    base_path,
+                    expected_patch_path,
+                    actual_patch_path,
+                })
+            }
+            _ => Err(
+                "expected either --json <file> or all of --base, --expected-patch, and --actual-patch"
+                    .to_string(),
+            ),
+        }
+    }
+}
+
+fn path_arg(args: &[String], index: usize, flag: &str) -> Result<std::path::PathBuf, String> {
+    Ok(Path::new(string_arg(args, index, flag)?).to_path_buf())
+}
+
+fn string_arg<'a>(args: &'a [String], index: usize, flag: &str) -> Result<&'a str, String> {
+    args.get(index)
+        .map(|value| value.as_str())
+        .ok_or_else(|| format!("missing value for {flag}"))
+}
+
+#[derive(Debug)]
+struct EvaluationReport {
+    base: String,
+    expected: String,
+    actual: String,
+    kept_rate: KeptRateResult,
+    exact_lines: ClassificationMetrics,
+    delta_chr_f: DeltaChrFMetrics,
+    expected_changed_lines: usize,
+    actual_changed_lines: usize,
+    token_changes: edit_prediction_metrics::TokenChangeCounts,
+    isolated_whitespace_changes: bool,
+    editable_region_correct: bool,
+    expected_braces_disbalance: usize,
+    actual_braces_disbalance: usize,
+}
+
+impl EvaluationReport {
+    fn new(
+        base: String,
+        expected_patch: String,
+        actual_patch: String,
+        expected: String,
+        actual: String,
+    ) -> Self {
+        let kept_rate = compute_kept_rate(&base, &actual, &expected);
+        let exact_lines = exact_lines_match(&expected_patch, &actual_patch);
+        let delta_chr_f = delta_chr_f(&base, &expected, &actual);
+        let expected_changed_lines = extract_changed_lines_from_diff(&expected_patch)
+            .values()
+            .sum();
+        let actual_changed_lines = extract_changed_lines_from_diff(&actual_patch)
+            .values()
+            .sum();
+        let token_changes = count_patch_token_changes(&actual_patch);
+        let isolated_whitespace_changes = has_isolated_whitespace_changes(&actual_patch, None);
+        let editable_region_correct = is_editable_region_correct(&actual_patch);
+        let expected_braces_disbalance = braces_disbalance(&expected);
+        let actual_braces_disbalance = braces_disbalance(&actual);
+
+        Self {
+            base,
+            expected,
+            actual,
+            kept_rate,
+            exact_lines,
+            delta_chr_f,
+            expected_changed_lines,
+            actual_changed_lines,
+            token_changes,
+            isolated_whitespace_changes,
+            editable_region_correct,
+            expected_braces_disbalance,
+            actual_braces_disbalance,
+        }
+    }
+}
+
+fn print_report(report: &EvaluationReport) {
+    println!("Metrics");
+    println!("=======");
+    println!("kept_rate: {:.6}", report.kept_rate.kept_rate);
+    println!("kept_rate_recall: {:.6}", report.kept_rate.recall_rate);
+    println!("delta_chr_f: {:.6}", report.delta_chr_f.score);
+    println!("delta_chr_f_precision: {:.6}", report.delta_chr_f.precision);
+    println!("delta_chr_f_recall: {:.6}", report.delta_chr_f.recall);
+    println!("delta_chr_f_beta: {:.6}", report.delta_chr_f.beta);
+    println!();
+
+    println!("Exact line match");
+    println!("----------------");
+    println!("true_positives: {}", report.exact_lines.true_positives);
+    println!("false_positives: {}", report.exact_lines.false_positives);
+    println!("false_negatives: {}", report.exact_lines.false_negatives);
+    println!("precision: {:.6}", report.exact_lines.precision());
+    println!("recall: {:.6}", report.exact_lines.recall());
+    println!("f1: {:.6}", report.exact_lines.f1());
+    println!("expected_changed_lines: {}", report.expected_changed_lines);
+    println!("actual_changed_lines: {}", report.actual_changed_lines);
+    println!();
+
+    println!("Patch structure");
+    println!("---------------");
+    println!("inserted_tokens: {}", report.token_changes.inserted_tokens);
+    println!("deleted_tokens: {}", report.token_changes.deleted_tokens);
+    println!(
+        "isolated_whitespace_changes: {}",
+        report.isolated_whitespace_changes
+    );
+    println!(
+        "editable_region_correct: {}",
+        report.editable_region_correct
+    );
+    println!();
+
+    println!("Final text checks");
+    println!("-----------------");
+    println!(
+        "expected_braces_disbalance: {}",
+        report.expected_braces_disbalance
+    );
+    println!(
+        "actual_braces_disbalance: {}",
+        report.actual_braces_disbalance
+    );
+    println!();
+
+    println!("Kept-rate breakdown");
+    println!("-------------------");
+    println!(
+        "candidate_new_chars: {}",
+        report.kept_rate.candidate_new_chars
+    );
+    println!(
+        "reference_new_chars: {}",
+        report.kept_rate.reference_new_chars
+    );
+    println!(
+        "candidate_deleted_chars: {}",
+        report.kept_rate.candidate_deleted_chars
+    );
+    println!(
+        "reference_deleted_chars: {}",
+        report.kept_rate.reference_deleted_chars
+    );
+    println!("kept_chars: {}", report.kept_rate.kept_chars);
+    println!(
+        "correctly_deleted_chars: {}",
+        report.kept_rate.correctly_deleted_chars
+    );
+    println!("discarded_chars: {}", report.kept_rate.discarded_chars);
+    println!("context_chars: {}", report.kept_rate.context_chars);
+    println!();
+
+    print_kept_rate_explanation(&report.base, &report.actual, &report.expected);
+}
+
+fn print_kept_rate_explanation(base: &str, actual: &str, expected: &str) {
+    println!("Kept-rate explanation");
+    println!("---------------------");
+    println!("Legend: context = default, kept = green background, discarded = red background");
+    println!();
+
+    let annotated = annotate_kept_rate_tokens(base, actual, expected);
+    println!("Actual final text with token annotations:");
+    println!("{}", render_annotated_tokens(&annotated));
+    println!();
+}
+
+/// Render annotated tokens to one string, styling kept/discarded tokens with
+/// ANSI background colours and leaving context tokens unstyled.
+fn render_annotated_tokens(tokens: &[edit_prediction_metrics::AnnotatedToken]) -> String {
+    const RESET: &str = "\x1b[0m";
+    // Black text on a green (kept) or red (discarded) background.
+    const KEPT_STYLE: &str = "\x1b[30;42m";
+    const DISCARDED_STYLE: &str = "\x1b[30;41m";
+
+    let mut rendered = String::new();
+    for token in tokens {
+        let style = match token.annotation {
+            TokenAnnotation::Context => "",
+            TokenAnnotation::Kept => KEPT_STYLE,
+            TokenAnnotation::Discarded => DISCARDED_STYLE,
+        };
+
+        if style.is_empty() {
+            rendered.push_str(&visualize_whitespace(&token.token));
+        } else {
+            rendered.push_str(style);
+            rendered.push_str(&visualize_whitespace(&token.token));
+            rendered.push_str(RESET);
+        }
+    }
+    rendered
+}
+
+/// Make whitespace visible: "·" for spaces, "⇥" for tabs, and "↵" before each
+/// newline (the newline itself is kept so line layout is preserved).
+fn visualize_whitespace(token: &str) -> String {
+    let mut rendered = String::new();
+    for ch in token.chars() {
+        match ch {
+            ' ' => rendered.push('·'),
+            '\t' => rendered.push('⇥'),
+            '\n' => rendered.push_str("↵\n"),
+            _ => rendered.push(ch),
+        }
+    }
+    rendered
+}
+
+/// One evaluation example deserialized from the JSON input.
+#[derive(Debug, Deserialize)]
+struct JsonExample {
+    prompt_inputs: PromptInputs,
+    expected_patches: Vec<String>,
+    predictions: Vec<Prediction>,
+}
+
+/// Inputs the prompt was built from for this example.
+#[derive(Debug, Deserialize)]
+struct PromptInputs {
+    cursor_excerpt: String,
+    excerpt_start_row: u32,
+}
+
+/// A single model prediction for an example.
+#[derive(Debug, Deserialize)]
+struct Prediction {
+    actual_patch: String,
+}
+
+/// A hunk parsed from a unified diff; `old_start` is the 1-based start row
+/// taken from the `@@` header.
+#[derive(Debug, Clone)]
+struct ParsedHunk {
+    old_start: u32,
+    lines: Vec<HunkLine>,
+}
+
+/// One line of a hunk body: unchanged context, an added line, or a deleted line.
+#[derive(Debug, Clone)]
+enum HunkLine {
+    Context(String),
+    Addition(String),
+    Deletion(String),
+}
+
+/// Apply `patch` to the excerpt text `base`, which begins at row
+/// `excerpt_start_row` of the full file. Hunk line numbers are first
+/// interpreted as file-global; if that produces no change (or fails), they
+/// are retried as excerpt-relative.
+fn apply_patch_to_excerpt(
+    base: &str,
+    patch: &str,
+    excerpt_start_row: u32,
+) -> Result<String, String> {
+    let hunks = parse_diff_hunks(patch);
+
+    let result = try_apply_hunks(base, &hunks, excerpt_start_row);
+
+    // Predicted patches may use excerpt-relative line numbers instead of
+    // file-global ones. When all hunks fall outside the excerpt window the
+    // result is identical to the base text. Retry with a zero offset so the
+    // line numbers are interpreted relative to the excerpt.
+    if excerpt_start_row > 0 && !hunks.is_empty() {
+        let should_retry = match &result {
+            Ok(text) => text == base,
+            Err(_) => true,
+        };
+
+        if should_retry {
+            let fallback = try_apply_hunks(base, &hunks, 0);
+            // Only prefer the fallback when it actually changed something.
+            if matches!(&fallback, Ok(text) if text != base) {
+                return fallback;
+            }
+        }
+    }
+
+    result
+}
+
+/// Apply `hunks` to `base`, treating hunk line numbers as rows of the full
+/// file and `base` as the excerpt starting at `excerpt_start_row`. Hunk
+/// content outside the excerpt is filtered out; a context mismatch or
+/// out-of-bounds application returns `Err` with a description.
+fn try_apply_hunks(
+    base: &str,
+    hunks: &[ParsedHunk],
+    excerpt_start_row: u32,
+) -> Result<String, String> {
+    let base_has_trailing_newline = base.ends_with('\n');
+    let mut lines = split_preserving_final_empty_line(base);
+    let original_line_count = lines.len() as u32;
+
+    let excerpt_end_row = excerpt_start_row + original_line_count;
+    // Net line-count change from earlier hunks; shifts later hunk positions.
+    let mut line_delta: i64 = 0;
+
+    for hunk in hunks {
+        let filtered = match filter_hunk_to_excerpt(hunk, excerpt_start_row, excerpt_end_row) {
+            Some(filtered) => filtered,
+            None => continue,
+        };
+
+        let local_start = filtered.old_start.saturating_sub(excerpt_start_row) as i64 + line_delta;
+        if local_start < 0 {
+            return Err(format!(
+                "patch application moved before excerpt start at source row {}",
+                filtered.old_start
+            ));
+        }
+        let local_start = local_start as usize;
+
+        if local_start > lines.len() {
+            return Err(format!(
+                "patch application starts past excerpt end at local line {}",
+                local_start + 1
+            ));
+        }
+
+        // Lines this hunk consumes from the old text and produces in the new.
+        let old_len = filtered
+            .lines
+            .iter()
+            .filter(|line| !matches!(line, HunkLine::Addition(_)))
+            .count();
+        let new_len = filtered
+            .lines
+            .iter()
+            .filter(|line| !matches!(line, HunkLine::Deletion(_)))
+            .count();
+
+        let old_segment: Vec<&str> = filtered
+            .lines
+            .iter()
+            .filter_map(|line| match line {
+                HunkLine::Context(text) | HunkLine::Deletion(text) => Some(text.as_str()),
+                HunkLine::Addition(_) => None,
+            })
+            .collect();
+
+        let new_segment: Vec<String> = filtered
+            .lines
+            .iter()
+            .filter_map(|line| match line {
+                HunkLine::Context(text) | HunkLine::Addition(text) => Some(text.clone()),
+                HunkLine::Deletion(_) => None,
+            })
+            .collect();
+
+        if local_start + old_len > lines.len() {
+            return Err(format!(
+                "patch application exceeds excerpt bounds near source row {}",
+                filtered.old_start
+            ));
+        }
+
+        // The hunk's context/deletion lines must match the text currently there.
+        let current_segment: Vec<&str> = lines[local_start..local_start + old_len]
+            .iter()
+            .map(String::as_str)
+            .collect();
+
+        if current_segment != old_segment {
+            let mut details = String::new();
+            let _ = write!(
+                details,
+                "patch context mismatch near source row {}: expected {:?}, found {:?}",
+                filtered.old_start, old_segment, current_segment
+            );
+            return Err(details);
+        }
+
+        lines.splice(local_start..local_start + old_len, new_segment);
+        line_delta += new_len as i64 - old_len as i64;
+    }
+
+    Ok(join_lines(&lines, base_has_trailing_newline))
+}
+
+/// Split `text` into lines, appending a sentinel empty line whenever the text
+/// ends with a newline so `join_lines` can reconstruct it exactly.
+fn split_preserving_final_empty_line(text: &str) -> Vec<String> {
+    let mut lines: Vec<String> = text.lines().map(ToString::to_string).collect();
+    // Always push the sentinel. The previous condition skipped it when the
+    // final line was empty (e.g. "a\n\n"), making that text indistinguishable
+    // from "a\n" and dropping a newline on the round-trip through `join_lines`.
+    if text.ends_with('\n') {
+        lines.push(String::new());
+    }
+    lines
+}
+
+/// Inverse of `split_preserving_final_empty_line`: join lines with "\n" and
+/// normalize the trailing newline to match the original text.
+fn join_lines(lines: &[String], had_trailing_newline: bool) -> String {
+    if lines.is_empty() {
+        return String::new();
+    }
+
+    let mut joined = lines.join("\n");
+    if had_trailing_newline && !joined.ends_with('\n') {
+        joined.push('\n');
+    }
+    if !had_trailing_newline && joined.ends_with('\n') {
+        joined.pop();
+    }
+    joined
+}
+
+/// Restrict `hunk` to the 0-based row window `[excerpt_start_row,
+/// excerpt_end_row)`, dropping lines that fall outside. Returns `None` when
+/// nothing in the hunk overlaps the excerpt.
+fn filter_hunk_to_excerpt(
+    hunk: &ParsedHunk,
+    excerpt_start_row: u32,
+    excerpt_end_row: u32,
+) -> Option<ParsedHunk> {
+    let mut filtered_lines = Vec::new();
+    // Hunk headers are 1-based; convert to a 0-based row counter.
+    let mut current_old_row = hunk.old_start.saturating_sub(1);
+    let mut filtered_old_start = None;
+    let mut has_overlap = false;
+
+    for line in &hunk.lines {
+        match line {
+            HunkLine::Context(text) => {
+                let in_excerpt =
+                    current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
+                if in_excerpt {
+                    filtered_old_start.get_or_insert(current_old_row);
+                    filtered_lines.push(HunkLine::Context(text.clone()));
+                    has_overlap = true;
+                }
+                current_old_row += 1;
+            }
+            HunkLine::Deletion(text) => {
+                let in_excerpt =
+                    current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
+                if in_excerpt {
+                    filtered_old_start.get_or_insert(current_old_row);
+                    filtered_lines.push(HunkLine::Deletion(text.clone()));
+                    has_overlap = true;
+                }
+                current_old_row += 1;
+            }
+            HunkLine::Addition(text) => {
+                // Additions consume no old row; `<=` lets an insertion land
+                // exactly at the excerpt end (an append).
+                let insertion_in_excerpt =
+                    current_old_row >= excerpt_start_row && current_old_row <= excerpt_end_row;
+                if insertion_in_excerpt {
+                    filtered_old_start.get_or_insert(current_old_row);
+                    filtered_lines.push(HunkLine::Addition(text.clone()));
+                    has_overlap = true;
+                }
+            }
+        }
+    }
+
+    if !has_overlap {
+        return None;
+    }
+
+    Some(ParsedHunk {
+        old_start: filtered_old_start.unwrap_or(excerpt_start_row),
+        lines: filtered_lines,
+    })
+}
+
+/// Parse the hunks of a unified diff. Lines before the first `@@` header are
+/// ignored, as are `+++`/`---` file-header lines inside the body.
+fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
+    let mut hunks = Vec::new();
+    let mut current_hunk: Option<ParsedHunk> = None;
+
+    for line in diff.lines() {
+        if let Some((old_start, old_count, _new_start, _new_count)) = parse_hunk_header(line) {
+            if let Some(hunk) = current_hunk.take() {
+                hunks.push(hunk);
+            }
+            // Header counts aren't used; lengths come from the hunk body.
+            let _ = old_count;
+            current_hunk = Some(ParsedHunk {
+                old_start,
+                lines: Vec::new(),
+            });
+            continue;
+        }
+
+        let Some(hunk) = current_hunk.as_mut() else {
+            continue;
+        };
+
+        if let Some(text) = line.strip_prefix('+') {
+            if !line.starts_with("+++") {
+                hunk.lines.push(HunkLine::Addition(text.to_string()));
+            }
+        } else if let Some(text) = line.strip_prefix('-') {
+            if !line.starts_with("---") {
+                hunk.lines.push(HunkLine::Deletion(text.to_string()));
+            }
+        } else if let Some(text) = line.strip_prefix(' ') {
+            hunk.lines.push(HunkLine::Context(text.to_string()));
+        } else if line.is_empty() {
+            // An entirely empty line is treated as empty context (some diff
+            // producers drop the leading space on blank context lines).
+            hunk.lines.push(HunkLine::Context(String::new()));
+        }
+    }
+
+    if let Some(hunk) = current_hunk {
+        hunks.push(hunk);
+    }
+
+    hunks
+}
+
+/// Parse a hunk header `@@ -old_start[,old_count] +new_start[,new_count] @@…`,
+/// returning `(old_start, old_count, new_start, new_count)`.
+fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
+    let line = line.strip_prefix("@@ -")?;
+    let (old_part, rest) = line.split_once(' ')?;
+    let rest = rest.strip_prefix('+')?;
+    let (new_part, _) = rest.split_once(" @@")?;
+
+    let (old_start, old_count) = parse_hunk_range(old_part)?;
+    let (new_start, new_count) = parse_hunk_range(new_part)?;
+    Some((old_start, old_count, new_start, new_count))
+}
+
+/// Parse `start[,count]`; the count defaults to 1 when omitted.
+fn parse_hunk_range(part: &str) -> Option<(u32, u32)> {
+    if let Some((start, count)) = part.split_once(',') {
+        Some((start.parse().ok()?, count.parse().ok()?))
+    } else {
+        Some((part.parse().ok()?, 1))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // End-to-end checks of apply_patch_to_excerpt in both addressing modes.
+    #[test]
+    fn applies_patch_in_file_mode() {
+        let base = "fn main() {\n    println!(\"hello\");\n}\n";
+        let patch = "@@ -1,3 +1,3 @@\n fn main() {\n-    println!(\"hello\");\n+    println!(\"world\");\n }\n";
+
+        let actual = apply_patch_to_excerpt(base, patch, 0).unwrap();
+        assert_eq!(actual, "fn main() {\n    println!(\"world\");\n}\n");
+    }
+
+    #[test]
+    fn applies_patch_in_json_excerpt_mode() {
+        let base = "b\nc\nd\n";
+        // Excerpt starts at file row 1: hunk line 2 (1-based) maps to the
+        // excerpt's first line.
+        let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";
+
+        let actual = apply_patch_to_excerpt(base, patch, 1).unwrap();
+        assert_eq!(actual, "x\ny\nd\n");
+    }
+
+    #[test]
+    fn applies_patch_with_excerpt_relative_line_numbers() {
+        let base = "a\nb\nc\nd\n";
+        // Patch uses excerpt-relative line numbers (line 2 of excerpt)
+        // even though the excerpt starts at file row 100.
+        let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";
+
+        let actual = apply_patch_to_excerpt(base, patch, 100).unwrap();
+        assert_eq!(actual, "a\nx\ny\nd\n");
+    }
+
+    #[test]
+    fn prefers_file_global_line_numbers_over_excerpt_relative() {
+        let base = "a\nb\nc\n";
+        // Patch uses file-global line numbers: excerpt starts at row 5,
+        // hunk targets line 6 (1-based) = row 5 (0-based) = first line.
+        let patch = "@@ -6,2 +6,2 @@\n-a\n-b\n+x\n+y\n";
+
+        let actual = apply_patch_to_excerpt(base, patch, 5).unwrap();
+        assert_eq!(actual, "x\ny\nc\n");
+    }
+}

crates/edit_prediction_metrics/src/patch_metrics.rs ๐Ÿ”—

@@ -687,6 +687,35 @@ fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<DiffOp> {
         .collect()
 }
 
+/// Reconstruct old and new text from a unified diff.
+///
+/// Context and deletion lines form the old text; context and addition
+/// lines form the new text; `Garbage` lines are ignored. Each side is
+/// joined with "\n" (no trailing newline).
+/// Returns `(old_text, new_text)`.
+pub fn reconstruct_texts_from_diff(patch_str: &str) -> (String, String) {
+    let patch = Patch::parse_unified_diff(patch_str);
+    let mut old_lines: Vec<&str> = Vec::new();
+    let mut new_lines: Vec<&str> = Vec::new();
+
+    for hunk in &patch.hunks {
+        for line in &hunk.lines {
+            match line {
+                PatchLine::Context(content) => {
+                    old_lines.push(content);
+                    new_lines.push(content);
+                }
+                PatchLine::Deletion(content) => {
+                    old_lines.push(content);
+                }
+                PatchLine::Addition(content) => {
+                    new_lines.push(content);
+                }
+                PatchLine::Garbage(_) => {}
+            }
+        }
+    }
+
+    (old_lines.join("\n"), new_lines.join("\n"))
+}
 #[derive(Debug, Default, Clone)]
 struct Patch {
     hunks: Vec<Hunk>,

crates/edit_prediction_metrics/src/tokenize.rs ๐Ÿ”—

@@ -1,33 +1,158 @@
-fn char_class(character: char) -> u8 {
-    if character.is_alphanumeric() || character == '_' {
-        0
+use std::{iter::Peekable, str::CharIndices};
+
+/// Coarse character classes used to group adjacent characters into tokens.
+#[derive(Clone, Copy, PartialEq, Eq)]
+enum CharClass {
+    Identifier,
+    Newline,
+    Whitespace,
+    Punctuation,
+}
+
+// Matched with `starts_with` in order, so a longer operator must appear
+// before any operator that is its prefix (e.g. ">>>" before ">>", "..=" and
+// "..." before "..").
+const MULTI_CHAR_PUNCTUATION: &[&str] = &[
+    ">>>=", "<<=", ">>=", "...", "..=", "??=", "**=", ">>>", "::", "->", "=>", "==", "!=", "<=",
+    ">=", "&&", "||", "<<", ">>", "..", "+=", "-=", "*=", "/=", "%=", "&=", "|=", "^=", "++", "--",
+    "**", "??", "?.", ":=", "<-", "//", "/*", "*/",
+];
+
+// '\n' and '\r' are also whitespace, so they are checked first to keep
+// newlines in their own class.
+fn char_class(character: char) -> CharClass {
+    if character == '\n' || character == '\r' {
+        CharClass::Newline
     } else if character.is_whitespace() {
-        1
+        CharClass::Whitespace
+    } else if character.is_alphanumeric() || character == '_' {
+        CharClass::Identifier
     } else {
-        2
+        CharClass::Punctuation
     }
 }
 
+// True at camelCase boundaries: a lower/digit→upper transition, or the last
+// capital of an acronym run followed by lowercase ("HTTPServer" → HTTP|Server).
+fn is_identifier_boundary(previous: char, current: char, next: Option<char>) -> bool {
+    (current.is_uppercase() && (previous.is_lowercase() || previous.is_numeric()))
+        || (current.is_uppercase()
+            && previous.is_uppercase()
+            && next.is_some_and(|next| next.is_lowercase()))
+}
+
+/// Split an identifier into sub-tokens: runs of underscores become their own
+/// tokens, and camelCase/acronym boundaries split the remaining segments.
+fn push_identifier_tokens<'a>(identifier: &'a str, tokens: &mut Vec<&'a str>) {
+    let characters: Vec<(usize, char)> = identifier.char_indices().collect();
+    let mut segment_start = 0;
+    let mut index = 0;
+
+    while index < characters.len() {
+        let (byte_index, character) = characters[index];
+
+        if character == '_' {
+            // Emit the segment before the underscore run, then the run itself.
+            if segment_start < byte_index {
+                tokens.push(&identifier[segment_start..byte_index]);
+            }
+
+            let mut underscore_end = byte_index + character.len_utf8();
+            index += 1;
+
+            while index < characters.len() && characters[index].1 == '_' {
+                underscore_end = characters[index].0 + characters[index].1.len_utf8();
+                index += 1;
+            }
+
+            tokens.push(&identifier[byte_index..underscore_end]);
+            segment_start = underscore_end;
+            continue;
+        }
+
+        if byte_index > segment_start {
+            let previous = characters[index - 1].1;
+            let next = characters.get(index + 1).map(|(_, character)| *character);
+
+            // Close the current segment at a case boundary.
+            if is_identifier_boundary(previous, character, next) {
+                tokens.push(&identifier[segment_start..byte_index]);
+                segment_start = byte_index;
+            }
+        }
+
+        index += 1;
+    }
+
+    // Flush the trailing segment, if any.
+    if segment_start < identifier.len() {
+        tokens.push(&identifier[segment_start..]);
+    }
+}
+
+/// Push one punctuation token starting at `start`, preferring the first
+/// matching entry of `MULTI_CHAR_PUNCTUATION`; falls back to the single
+/// character.
+fn push_punctuation_token<'a>(
+    text: &'a str,
+    start: usize,
+    character: char,
+    characters: &mut Peekable<CharIndices<'a>>,
+    tokens: &mut Vec<&'a str>,
+) {
+    let remaining = &text[start..];
+
+    for punctuation in MULTI_CHAR_PUNCTUATION {
+        if remaining.starts_with(punctuation) {
+            // Consume the operator's extra characters from the iterator so the
+            // caller's scan resumes after the whole token.
+            for _ in punctuation.chars().skip(1) {
+                characters.next();
+            }
+
+            tokens.push(&remaining[..punctuation.len()]);
+            return;
+        }
+    }
+
+    let end = start + character.len_utf8();
+    tokens.push(&text[start..end]);
+}
+
 pub(crate) fn tokenize(text: &str) -> Vec<&str> {
     let mut tokens = Vec::new();
     let mut characters = text.char_indices().peekable();
 
     while let Some((start, character)) = characters.next() {
-        let class = char_class(character);
-        if class == 2 {
-            tokens.push(&text[start..start + character.len_utf8()]);
-            continue;
-        }
+        match char_class(character) {
+            CharClass::Identifier => {
+                // Consume the whole identifier run, then split it into
+                // camelCase/underscore sub-tokens.
+                let mut end = start + character.len_utf8();
+
+                while let Some(&(next_start, next_character)) = characters.peek() {
+                    if char_class(next_character) != CharClass::Identifier {
+                        break;
+                    }
+
+                    end = next_start + next_character.len_utf8();
+                    characters.next();
+                }
+
+                push_identifier_tokens(&text[start..end], &mut tokens);
+            }
+            CharClass::Newline => {
+                // A run of newline characters forms a single token.
+                let mut end = start + character.len_utf8();
+
+                while let Some(&(next_start, next_character)) = characters.peek() {
+                    if char_class(next_character) != CharClass::Newline {
+                        break;
+                    }
+
+                    end = next_start + next_character.len_utf8();
+                    characters.next();
+                }
 
-        let mut end = start + character.len_utf8();
-        while let Some(&(_, next_character)) = characters.peek() {
-            if char_class(next_character) != class {
-                break;
+                tokens.push(&text[start..end]);
+            }
+            CharClass::Whitespace => {
+                // A run of non-newline whitespace forms a single token.
+                let mut end = start + character.len_utf8();
+
+                while let Some(&(next_start, next_character)) = characters.peek() {
+                    if char_class(next_character) != CharClass::Whitespace {
+                        break;
+                    }
+
+                    end = next_start + next_character.len_utf8();
+                    characters.next();
+                }
+
+                tokens.push(&text[start..end]);
+            }
+            CharClass::Punctuation => {
+                // Longest matching multi-char operator, or a single character.
+                push_punctuation_token(text, start, character, &mut characters, &mut tokens);
             }
-            end += next_character.len_utf8();
-            characters.next();
         }
-        tokens.push(&text[start..end]);
     }
 
     tokens
@@ -38,17 +163,58 @@ mod tests {
     use super::tokenize;
 
     #[test]
-    fn tokenizes_code_like_text() {
+    fn tokenizes_code() {
         assert_eq!(tokenize("hello world"), vec!["hello", " ", "world"]);
         assert_eq!(
             tokenize("foo_bar123 + baz"),
-            vec!["foo_bar123", " ", "+", " ", "baz"]
+            vec!["foo", "_", "bar123", " ", "+", " ", "baz"]
         );
         assert_eq!(
             tokenize("print(\"hello\")"),
            vec!["print", "(", "\"", "hello", "\"", ")"]
         );
-        assert_eq!(tokenize("hello_world"), vec!["hello_world"]);
+        assert_eq!(tokenize("hello_world"), vec!["hello", "_", "world"]);
         assert_eq!(tokenize("fn();"), vec!["fn", "(", ")", ";"]);
     }
+
+    #[test]
+    fn tokenizes_identifier_case_styles() {
+        assert_eq!(
+            tokenize("camelCase PascalCase snake_case"),
+            vec![
+                "camel", "Case", " ", "Pascal", "Case", " ", "snake", "_", "case"
+            ]
+        );
+        assert_eq!(
+            tokenize("myHTTPServer __private_value foo__bar"),
+            vec![
+                "my", "HTTP", "Server", " ", "__", "private", "_", "value", " ", "foo", "__", "bar"
+            ]
+        );
+        assert_eq!(
+            tokenize("XMLHttpRequest Version2Update"),
+            vec!["XML", "Http", "Request", " ", "Version2", "Update"]
+        );
+    }
+
+    #[test]
+    fn tokenizes_grouped_punctuation() {
+        assert_eq!(
+            tokenize("a::b -> c != d ..= e"),
+            vec![
+                "a", "::", "b", " ", "->", " ", "c", " ", "!=", " ", "d", " ", "..=", " ", "e"
+            ]
+        );
+        assert_eq!(
+            tokenize("foo?.bar ?? baz"),
+            vec!["foo", "?.", "bar", " ", "??", " ", "baz"]
+        );
+    }
+
+    #[test]
+    fn tokenize_whitespace_runs() {
+        assert_eq!(tokenize("  "), vec!["  "]);
+        assert_eq!(tokenize("  \n   foo"), vec!["  ", "\n", "   ", "foo"]);
+        // Consecutive newline characters (\r and \n) group into one token.
+        assert_eq!(tokenize("\r\n\nfoo"), vec!["\r\n\n", "foo"]);
+    }
 }

typos.toml ๐Ÿ”—

@@ -63,6 +63,7 @@ extend-exclude = [
     "crates/gpui_macos/src/dispatcher.rs",
     # Tests contain partially incomplete words (by design)
     "crates/edit_prediction_cli/src/split_commit.rs",
+    "crates/edit_prediction_metrics/src/kept_rate.rs",
     # Eval examples contain intentionally partial words (e.g. "secur" for "secure")
     "crates/edit_prediction_cli/evals/",
     # Tests contain `baห‡r` that cause `"ba" should be "by" or "be".`-like false-positives