Cargo.lock
@@ -5338,6 +5338,7 @@ dependencies = [
"language",
"pretty_assertions",
"serde",
+ "serde_json",
"similar",
"tree-sitter",
"zeta_prompt",
Oleksiy Syvokon created
This change contains a number of fixes to make kept_rate more intuitive.
It also adds a CLI utility to print debug info on how the metric is
computed.
Release Notes:
- N/A
Cargo.lock | 1
crates/edit_prediction_metrics/Cargo.toml | 1
crates/edit_prediction_metrics/src/edit_prediction_metrics.rs | 4
crates/edit_prediction_metrics/src/kept_rate.rs | 283 +
crates/edit_prediction_metrics/src/main.rs | 710 +++++
crates/edit_prediction_metrics/src/patch_metrics.rs | 29
crates/edit_prediction_metrics/src/tokenize.rs | 206 +
typos.toml | 1
8 files changed, 1,145 insertions(+), 90 deletions(-)
@@ -5338,6 +5338,7 @@ dependencies = [
"language",
"pretty_assertions",
"serde",
+ "serde_json",
"similar",
"tree-sitter",
"zeta_prompt",
@@ -14,6 +14,7 @@ path = "src/edit_prediction_metrics.rs"
[dependencies]
language.workspace = true
serde.workspace = true
+serde_json = "1.0"
similar = "2.7.0"
tree-sitter.workspace = true
zeta_prompt.workspace = true
@@ -4,9 +4,10 @@ mod reversal;
mod tokenize;
mod tree_sitter;
+pub use kept_rate::AnnotatedToken;
pub use kept_rate::KeptRateResult;
-#[cfg(test)]
pub use kept_rate::TokenAnnotation;
+pub use kept_rate::annotate_kept_rate_tokens;
pub use kept_rate::compute_kept_rate;
pub use patch_metrics::ClassificationMetrics;
pub use patch_metrics::Counts;
@@ -20,5 +21,6 @@ pub use patch_metrics::exact_lines_match;
pub use patch_metrics::extract_changed_lines_from_diff;
pub use patch_metrics::has_isolated_whitespace_changes;
pub use patch_metrics::is_editable_region_correct;
+pub use patch_metrics::reconstruct_texts_from_diff;
pub use reversal::compute_prediction_reversal_ratio_from_history;
pub use tree_sitter::count_tree_sitter_errors;
@@ -3,14 +3,20 @@ use serde::Serialize;
const MAX_DIRTY_LENGTH_DELTA_CHARS: usize = 512;
-#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
+#[serde(rename_all = "snake_case")]
pub enum TokenAnnotation {
Context,
Kept,
Discarded,
}
+#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
+pub struct AnnotatedToken {
+ pub token: String,
+ pub annotation: TokenAnnotation,
+}
+
#[allow(dead_code)]
#[derive(Debug, Clone, Serialize)]
pub struct KeptRateResult {
@@ -40,8 +46,7 @@ pub struct KeptRateResult {
/// This includes both kept newly introduced characters and correctly
/// deleted base characters.
pub recall_rate: f64,
- /// Per-token classification for candidate tokens used by tests.
- #[cfg(test)]
+ /// Per-token classification for candidate tokens.
pub token_annotations: Vec<TokenAnnotation>,
}
@@ -51,9 +56,9 @@ fn dp_index(width: usize, row: usize, column: usize) -> usize {
/// Fill masks over `a` and `b` using one-sided LCS tie-breaking for each side
/// while sharing a single DP table construction.
-fn fill_lcs_keep_masks(
- a: &[&str],
- b: &[&str],
+fn fill_lcs_keep_masks<T: Eq>(
+ a: &[T],
+ b: &[T],
mut keep_a: Option<&mut [bool]>,
mut keep_b: Option<&mut [bool]>,
) {
@@ -124,10 +129,10 @@ fn fill_lcs_keep_masks(
let mut dp = vec![0u32; row_count * column_count];
for i in 1..row_count {
- let token_a = a_mid[i - 1];
+ let token_a = &a_mid[i - 1];
for j in 1..column_count {
let index = dp_index(column_count, i, j);
- if token_a == b_mid[j - 1] {
+ if token_a == &b_mid[j - 1] {
dp[index] = dp[dp_index(column_count, i - 1, j - 1)] + 1;
} else {
let up = dp[dp_index(column_count, i - 1, j)];
@@ -180,41 +185,91 @@ fn fill_lcs_keep_masks(
}
}
-fn lcs_keep_mask(a: &[&str], b: &[&str]) -> Vec<bool> {
+fn lcs_keep_mask<T: Eq>(a: &[T], b: &[T]) -> Vec<bool> {
let mut keep_a = vec![false; a.len()];
fill_lcs_keep_masks(a, b, Some(&mut keep_a), None);
keep_a
}
-fn lcs_keep_masks(a: &[&str], b: &[&str]) -> (Vec<bool>, Vec<bool>) {
+fn lcs_keep_masks<T: Eq>(a: &[T], b: &[T]) -> (Vec<bool>, Vec<bool>) {
let mut keep_a = vec![false; a.len()];
let mut keep_b = vec![false; b.len()];
fill_lcs_keep_masks(a, b, Some(&mut keep_a), Some(&mut keep_b));
(keep_a, keep_b)
}
-fn analyze_masked_tokens<'a>(tokens: &[&'a str], mask: &[bool]) -> (Vec<&'a str>, usize, usize) {
- let mut unmasked_tokens = Vec::with_capacity(tokens.len());
+#[derive(Debug, Clone)]
+struct ComparisonUnit {
+ text: String,
+ token_start: usize,
+ token_end: usize,
+}
+
+fn is_identifier_token(token: &str) -> bool {
+ !token.is_empty()
+ && token
+ .chars()
+ .all(|character| character.is_alphanumeric() || character == '_')
+}
+
+fn build_comparison_units(tokens: &[&str]) -> Vec<ComparisonUnit> {
+ let mut units = Vec::new();
+ let mut index = 0;
+
+ while index < tokens.len() {
+ let token_start = index;
+
+ if is_identifier_token(tokens[index]) {
+ let mut text = String::new();
+
+ while index < tokens.len() && is_identifier_token(tokens[index]) {
+ text.push_str(tokens[index]);
+ index += 1;
+ }
+
+ units.push(ComparisonUnit {
+ text,
+ token_start,
+ token_end: index,
+ });
+ } else {
+ units.push(ComparisonUnit {
+ text: tokens[index].to_string(),
+ token_start,
+ token_end: index + 1,
+ });
+ index += 1;
+ }
+ }
+
+ units
+}
+
+fn analyze_masked_units<'a>(
+ units: &'a [ComparisonUnit],
+ mask: &[bool],
+) -> (Vec<&'a str>, usize, usize) {
+ let mut unmasked_units = Vec::with_capacity(units.len());
let mut unmasked_chars = 0;
let mut masked_chars = 0;
- for (&token, &is_masked) in tokens.iter().zip(mask.iter()) {
+ for (unit, &is_masked) in units.iter().zip(mask.iter()) {
if is_masked {
- masked_chars += token.len();
+ masked_chars += unit.text.len();
} else {
- unmasked_tokens.push(token);
- unmasked_chars += token.len();
+ unmasked_units.push(unit.text.as_str());
+ unmasked_chars += unit.text.len();
}
}
- (unmasked_tokens, unmasked_chars, masked_chars)
+ (unmasked_units, unmasked_chars, masked_chars)
}
-fn count_unmasked_chars(tokens: &[&str], mask: &[bool]) -> usize {
- tokens
+fn count_unmasked_unit_chars(units: &[ComparisonUnit], mask: &[bool]) -> usize {
+ units
.iter()
.zip(mask.iter())
- .filter_map(|(&token, &is_masked)| (!is_masked).then_some(token.len()))
+ .filter_map(|(unit, &is_masked)| (!is_masked).then_some(unit.text.len()))
.sum()
}
@@ -239,7 +294,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
context_chars,
kept_rate: 1.0,
recall_rate: 1.0,
- #[cfg(test)]
token_annotations: vec![TokenAnnotation::Context; candidate_tokens.len()],
};
}
@@ -258,7 +312,6 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
context_chars: 0,
kept_rate: 0.0,
recall_rate: 0.0,
- #[cfg(test)]
token_annotations: vec![TokenAnnotation::Discarded; tokenize(candidate).len()],
};
}
@@ -267,29 +320,29 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
let candidate_tokens = tokenize(candidate);
let reference_tokens = tokenize(reference);
- let (candidate_base_mask, base_candidate_mask) =
- lcs_keep_masks(&candidate_tokens, &base_tokens);
- let (candidate_reference_mask, reference_candidate_mask) =
- lcs_keep_masks(&candidate_tokens, &reference_tokens);
- let context_mask: Vec<bool> = candidate_base_mask
+ let candidate_units = build_comparison_units(&candidate_tokens);
+ let base_units = build_comparison_units(&base_tokens);
+ let reference_units = build_comparison_units(&reference_tokens);
+
+ let candidate_unit_texts: Vec<&str> = candidate_units
.iter()
- .zip(candidate_reference_mask.iter())
- .map(|(&in_base, &in_reference)| in_base && in_reference)
+ .map(|unit| unit.text.as_str())
+ .collect();
+ let base_unit_texts: Vec<&str> = base_units.iter().map(|unit| unit.text.as_str()).collect();
+ let reference_unit_texts: Vec<&str> = reference_units
+ .iter()
+ .map(|unit| unit.text.as_str())
.collect();
+ let (candidate_base_mask, base_candidate_mask) =
+ lcs_keep_masks(&candidate_unit_texts, &base_unit_texts);
let (stripped_candidate, candidate_new_chars, context_chars) =
- analyze_masked_tokens(&candidate_tokens, &context_mask);
+ analyze_masked_units(&candidate_units, &candidate_base_mask);
let (reference_base_mask, base_reference_mask) =
- lcs_keep_masks(&reference_tokens, &base_tokens);
- let reference_context_mask: Vec<bool> = reference_base_mask
- .iter()
- .zip(reference_candidate_mask.iter())
- .map(|(&in_base, &in_candidate)| in_base && in_candidate)
- .collect();
-
+ lcs_keep_masks(&reference_unit_texts, &base_unit_texts);
let (stripped_reference, reference_new_chars, _) =
- analyze_masked_tokens(&reference_tokens, &reference_context_mask);
+ analyze_masked_units(&reference_units, &reference_base_mask);
let keep_mask = lcs_keep_mask(&stripped_candidate, &stripped_reference);
@@ -299,13 +352,13 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
.filter_map(|(&token, &is_kept)| is_kept.then_some(token.len()))
.sum();
- let candidate_deleted_chars = count_unmasked_chars(&base_tokens, &base_candidate_mask);
- let reference_deleted_chars = count_unmasked_chars(&base_tokens, &base_reference_mask);
- let correctly_deleted_chars: usize = base_tokens
+ let candidate_deleted_chars = count_unmasked_unit_chars(&base_units, &base_candidate_mask);
+ let reference_deleted_chars = count_unmasked_unit_chars(&base_units, &base_reference_mask);
+ let correctly_deleted_chars: usize = base_units
.iter()
.zip(base_candidate_mask.iter().zip(base_reference_mask.iter()))
- .filter_map(|(&token, (&in_candidate, &in_reference))| {
- (!in_candidate && !in_reference).then_some(token.len())
+ .filter_map(|(unit, (&in_candidate, &in_reference))| {
+ (!in_candidate && !in_reference).then_some(unit.text.len())
})
.sum();
@@ -326,24 +379,28 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
matched_edit_chars as f64 / reference_edit_chars as f64
};
- #[cfg(test)]
let token_annotations = {
- let mut token_annotations = Vec::with_capacity(candidate_tokens.len());
+ let mut token_annotations = vec![TokenAnnotation::Context; candidate_tokens.len()];
let mut new_index = 0;
- for (token_index, _token) in candidate_tokens.iter().enumerate() {
- if context_mask[token_index] {
- token_annotations.push(TokenAnnotation::Context);
+
+ for (unit_index, unit) in candidate_units.iter().enumerate() {
+ let annotation = if candidate_base_mask[unit_index] {
+ TokenAnnotation::Context
} else {
let annotation = if keep_mask[new_index] {
TokenAnnotation::Kept
} else {
TokenAnnotation::Discarded
};
- #[cfg(test)]
- token_annotations.push(annotation);
new_index += 1;
+ annotation
+ };
+
+ for token_index in unit.token_start..unit.token_end {
+ token_annotations[token_index] = annotation;
}
}
+
token_annotations
};
@@ -358,14 +415,30 @@ pub fn compute_kept_rate(base: &str, candidate: &str, reference: &str) -> KeptRa
context_chars,
kept_rate,
recall_rate,
- #[cfg(test)]
token_annotations,
}
}
+pub fn annotate_kept_rate_tokens(
+ base: &str,
+ candidate: &str,
+ reference: &str,
+) -> Vec<AnnotatedToken> {
+ let result = compute_kept_rate(base, candidate, reference);
+ tokenize(candidate)
+ .into_iter()
+ .zip(result.token_annotations)
+ .map(|(token, annotation)| AnnotatedToken {
+ token: token.to_string(),
+ annotation,
+ })
+ .collect()
+}
+
#[cfg(test)]
mod test_kept_rate {
use super::*;
+ use indoc::indoc;
#[test]
fn test_lcs_keep_masks() {
@@ -439,16 +512,24 @@ mod test_kept_rate {
#[test]
fn test_missing_deletion() {
- let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
- let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\neprintln!(\"\");\n";
- let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
+ let base = indoc! {"
+ fn example() {
+ epr
+ "};
+ let candidate = indoc! {r#"
+ fn example() {
+ epr
+ eprintln!("");
+ "#};
+ let reference = indoc! {r#"
+ fn example() {
+ eprintln!("");
+ "#};
+
let result = compute_kept_rate(base, candidate, reference);
- assert!(
- result.kept_rate < 0.85,
- "expected kept_rate < 0.85, got {}",
- result.kept_rate
- );
- assert!(result.discarded_chars > 0);
+ assert!((result.kept_rate - (14.0 / 15.0)).abs() < 1e-6);
+ assert_eq!(result.kept_chars, 14);
+ assert_eq!(result.discarded_chars, 1);
}
#[test]
@@ -472,8 +553,17 @@ mod test_kept_rate {
#[test]
fn test_bails_for_dirty_final() {
- let base = "fn example() {\n work();\n}\n";
- let candidate = "fn example() {\n work();\n predicted();\n}\n";
+ let base = indoc! {"
+ fn example() {
+ work();
+ }
+ "};
+ let candidate = indoc! {"
+ fn example() {
+ work();
+ predicted();
+ }
+ "};
let reference = format!(
"fn example() {{\n work();\n {}\n}}\n",
"settled();\n ".repeat(MAX_DIRTY_LENGTH_DELTA_CHARS / 8 + 64)
@@ -488,9 +578,19 @@ mod test_kept_rate {
#[test]
fn test_eprintln_token_alignment() {
- let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
- let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
- let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
+ let base = indoc! {"
+ fn example() {
+ epr
+ "};
+ let candidate = indoc! {r#"
+ fn example() {
+ eprintln!("hello world!");
+ "#};
+ let reference = indoc! {r#"
+ fn example() {
+ eprintln!("");
+ "#};
+
let result = compute_kept_rate(base, candidate, reference);
assert!(result.discarded_chars > 0);
assert!(result.kept_chars > 0);
@@ -499,6 +599,42 @@ mod test_kept_rate {
assert_eq!(result.discarded_chars, 12);
}
+ #[test]
+ fn test_kept_rate_treats_unchanged_stale_text_as_context() {
+ let base = indoc! {"
+ a=fomr
+ b=old
+ "};
+ let candidate = indoc! {"
+ a=formula;
+ b=old
+ "};
+ let reference = indoc! {"
+ a=formula;
+ b=new
+ "};
+
+ let result = compute_kept_rate(base, candidate, reference);
+ let candidate_tokens = tokenize(candidate);
+
+ assert_eq!(result.candidate_new_chars, "formula".len() + ";".len());
+ assert_eq!(result.kept_chars, "formula".len() + ";".len());
+ assert_eq!(result.discarded_chars, 0);
+ assert_eq!(result.candidate_deleted_chars, "fomr".len());
+ assert_eq!(result.correctly_deleted_chars, "fomr".len());
+ assert!((result.kept_rate - 1.0).abs() < 1e-6);
+ assert!((result.recall_rate - (2.0 / 3.0)).abs() < 1e-6);
+
+ let old_index = candidate_tokens
+ .iter()
+ .position(|&token| token == "old")
+ .expect("old token not found");
+ assert_eq!(
+ result.token_annotations[old_index],
+ TokenAnnotation::Context
+ );
+ }
+
#[test]
fn test_annotations_rename() {
let base = " foo(old_name)\n";
@@ -514,7 +650,7 @@ mod test_kept_rate {
assert_eq!(result.token_annotations.len(), tokenize(candidate).len());
for (&token, &annotation) in tokenize(candidate).iter().zip(&result.token_annotations) {
- if token == "new_name" {
+ if matches!(token, "new" | "_" | "name") {
assert_eq!(annotation, TokenAnnotation::Kept);
} else {
assert_eq!(annotation, TokenAnnotation::Context);
@@ -524,9 +660,18 @@ mod test_kept_rate {
#[test]
fn test_annotations_eprintln_coloring() {
- let base = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n epr\n";
- let candidate = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"hello world!\");\n";
- let reference = " fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {\n eprintln!(\"\");\n";
+ let base = indoc! {"
+ fn example() {
+ epr
+ "};
+ let candidate = indoc! {r#"
+ fn example() {
+ eprintln!("hello world!");
+ "#};
+ let reference = indoc! {r#"
+ fn example() {
+ eprintln!("");
+ "#};
let result = compute_kept_rate(base, candidate, reference);
let candidate_tokens = tokenize(candidate);
@@ -0,0 +1,710 @@
+use std::env;
+use std::fmt::Write as _;
+use std::fs;
+use std::path::Path;
+use std::process;
+
+use edit_prediction_metrics::{
+ ClassificationMetrics, DeltaChrFMetrics, KeptRateResult, TokenAnnotation,
+ annotate_kept_rate_tokens, braces_disbalance, compute_kept_rate, count_patch_token_changes,
+ delta_chr_f, exact_lines_match, extract_changed_lines_from_diff,
+ has_isolated_whitespace_changes, is_editable_region_correct,
+};
+use serde::Deserialize;
+
+fn main() {
+ if let Err(error) = run() {
+ eprintln!("error: {error}");
+ process::exit(1);
+ }
+}
+
+fn run() -> Result<(), String> {
+ let args: Vec<String> = env::args().skip(1).collect();
+ if args.is_empty() {
+ print_usage();
+ return Err("missing arguments".to_string());
+ }
+
+ let input = CliInput::parse(&args)?;
+ let report = match input {
+ CliInput::Files {
+ base_path,
+ expected_patch_path,
+ actual_patch_path,
+ } => {
+ let base = fs::read_to_string(&base_path)
+ .map_err(|err| format!("failed to read {}: {err}", base_path.display()))?;
+ let expected_patch = fs::read_to_string(&expected_patch_path).map_err(|err| {
+ format!("failed to read {}: {err}", expected_patch_path.display())
+ })?;
+ let actual_patch = fs::read_to_string(&actual_patch_path)
+ .map_err(|err| format!("failed to read {}: {err}", actual_patch_path.display()))?;
+
+ let expected = apply_patch_to_excerpt(&base, &expected_patch, 0)?;
+ let actual = apply_patch_to_excerpt(&base, &actual_patch, 0)?;
+
+ EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
+ }
+ CliInput::Json {
+ json_path,
+ prediction_index,
+ } => {
+ let json = fs::read_to_string(&json_path)
+ .map_err(|err| format!("failed to read {}: {err}", json_path.display()))?;
+ let example: JsonExample = serde_json::from_str(&json)
+ .map_err(|err| format!("failed to parse {}: {err}", json_path.display()))?;
+
+ let base = example.prompt_inputs.cursor_excerpt;
+ let excerpt_start_row = example.prompt_inputs.excerpt_start_row;
+ let expected_patch = example
+ .expected_patches
+ .into_iter()
+ .next()
+ .ok_or_else(|| "JSON input is missing expected_patches[0]".to_string())?;
+ let actual_patch = example
+ .predictions
+ .into_iter()
+ .nth(prediction_index)
+ .ok_or_else(|| {
+ format!("JSON input does not contain predictions[{prediction_index}]")
+ })?
+ .actual_patch;
+
+ let expected = apply_patch_to_excerpt(&base, &expected_patch, excerpt_start_row)?;
+ let actual = apply_patch_to_excerpt(&base, &actual_patch, excerpt_start_row)?;
+
+ EvaluationReport::new(base, expected_patch, actual_patch, expected, actual)
+ }
+ };
+
+ print_report(&report);
+ Ok(())
+}
+
+fn print_usage() {
+ eprintln!(
+ "Usage:\n edit_prediction_metrics --base <base.txt> --expected-patch <expected.diff> --actual-patch <actual.diff>\n edit_prediction_metrics --json <example.json> [--prediction-index <n>]"
+ );
+}
+
+enum CliInput {
+ Files {
+ base_path: std::path::PathBuf,
+ expected_patch_path: std::path::PathBuf,
+ actual_patch_path: std::path::PathBuf,
+ },
+ Json {
+ json_path: std::path::PathBuf,
+ prediction_index: usize,
+ },
+}
+
+impl CliInput {
+ fn parse(args: &[String]) -> Result<Self, String> {
+ let mut base_path = None;
+ let mut expected_patch_path = None;
+ let mut actual_patch_path = None;
+ let mut json_path = None;
+ let mut prediction_index = 0usize;
+
+ let mut index = 0;
+ while index < args.len() {
+ match args[index].as_str() {
+ "--base" => {
+ index += 1;
+ base_path = Some(path_arg(args, index, "--base")?);
+ }
+ "--expected-patch" => {
+ index += 1;
+ expected_patch_path = Some(path_arg(args, index, "--expected-patch")?);
+ }
+ "--actual-patch" => {
+ index += 1;
+ actual_patch_path = Some(path_arg(args, index, "--actual-patch")?);
+ }
+ "--json" => {
+ index += 1;
+ json_path = Some(path_arg(args, index, "--json")?);
+ }
+ "--prediction-index" => {
+ index += 1;
+ let raw = string_arg(args, index, "--prediction-index")?;
+ prediction_index = raw.parse::<usize>().map_err(|err| {
+ format!("invalid value for --prediction-index ({raw}): {err}")
+ })?;
+ }
+ "--help" | "-h" => {
+ print_usage();
+ process::exit(0);
+ }
+ unknown => {
+ return Err(format!("unrecognized argument: {unknown}"));
+ }
+ }
+ index += 1;
+ }
+
+ if let Some(json_path) = json_path {
+ if base_path.is_some() || expected_patch_path.is_some() || actual_patch_path.is_some() {
+ return Err(
+ "--json cannot be combined with --base/--expected-patch/--actual-patch"
+ .to_string(),
+ );
+ }
+ return Ok(CliInput::Json {
+ json_path,
+ prediction_index,
+ });
+ }
+
+ match (base_path, expected_patch_path, actual_patch_path) {
+ (Some(base_path), Some(expected_patch_path), Some(actual_patch_path)) => {
+ Ok(CliInput::Files {
+ base_path,
+ expected_patch_path,
+ actual_patch_path,
+ })
+ }
+ _ => Err(
+ "expected either --json <file> or all of --base, --expected-patch, and --actual-patch"
+ .to_string(),
+ ),
+ }
+ }
+}
+
+fn path_arg(args: &[String], index: usize, flag: &str) -> Result<std::path::PathBuf, String> {
+ Ok(Path::new(string_arg(args, index, flag)?).to_path_buf())
+}
+
+fn string_arg<'a>(args: &'a [String], index: usize, flag: &str) -> Result<&'a str, String> {
+ args.get(index)
+ .map(|value| value.as_str())
+ .ok_or_else(|| format!("missing value for {flag}"))
+}
+
+#[derive(Debug)]
+struct EvaluationReport {
+ base: String,
+ expected: String,
+ actual: String,
+ kept_rate: KeptRateResult,
+ exact_lines: ClassificationMetrics,
+ delta_chr_f: DeltaChrFMetrics,
+ expected_changed_lines: usize,
+ actual_changed_lines: usize,
+ token_changes: edit_prediction_metrics::TokenChangeCounts,
+ isolated_whitespace_changes: bool,
+ editable_region_correct: bool,
+ expected_braces_disbalance: usize,
+ actual_braces_disbalance: usize,
+}
+
+impl EvaluationReport {
+ fn new(
+ base: String,
+ expected_patch: String,
+ actual_patch: String,
+ expected: String,
+ actual: String,
+ ) -> Self {
+ let kept_rate = compute_kept_rate(&base, &actual, &expected);
+ let exact_lines = exact_lines_match(&expected_patch, &actual_patch);
+ let delta_chr_f = delta_chr_f(&base, &expected, &actual);
+ let expected_changed_lines = extract_changed_lines_from_diff(&expected_patch)
+ .values()
+ .sum();
+ let actual_changed_lines = extract_changed_lines_from_diff(&actual_patch)
+ .values()
+ .sum();
+ let token_changes = count_patch_token_changes(&actual_patch);
+ let isolated_whitespace_changes = has_isolated_whitespace_changes(&actual_patch, None);
+ let editable_region_correct = is_editable_region_correct(&actual_patch);
+ let expected_braces_disbalance = braces_disbalance(&expected);
+ let actual_braces_disbalance = braces_disbalance(&actual);
+
+ Self {
+ base,
+ expected,
+ actual,
+ kept_rate,
+ exact_lines,
+ delta_chr_f,
+ expected_changed_lines,
+ actual_changed_lines,
+ token_changes,
+ isolated_whitespace_changes,
+ editable_region_correct,
+ expected_braces_disbalance,
+ actual_braces_disbalance,
+ }
+ }
+}
+
+fn print_report(report: &EvaluationReport) {
+ println!("Metrics");
+ println!("=======");
+ println!("kept_rate: {:.6}", report.kept_rate.kept_rate);
+ println!("kept_rate_recall: {:.6}", report.kept_rate.recall_rate);
+ println!("delta_chr_f: {:.6}", report.delta_chr_f.score);
+ println!("delta_chr_f_precision: {:.6}", report.delta_chr_f.precision);
+ println!("delta_chr_f_recall: {:.6}", report.delta_chr_f.recall);
+ println!("delta_chr_f_beta: {:.6}", report.delta_chr_f.beta);
+ println!();
+
+ println!("Exact line match");
+ println!("----------------");
+ println!("true_positives: {}", report.exact_lines.true_positives);
+ println!("false_positives: {}", report.exact_lines.false_positives);
+ println!("false_negatives: {}", report.exact_lines.false_negatives);
+ println!("precision: {:.6}", report.exact_lines.precision());
+ println!("recall: {:.6}", report.exact_lines.recall());
+ println!("f1: {:.6}", report.exact_lines.f1());
+ println!("expected_changed_lines: {}", report.expected_changed_lines);
+ println!("actual_changed_lines: {}", report.actual_changed_lines);
+ println!();
+
+ println!("Patch structure");
+ println!("---------------");
+ println!("inserted_tokens: {}", report.token_changes.inserted_tokens);
+ println!("deleted_tokens: {}", report.token_changes.deleted_tokens);
+ println!(
+ "isolated_whitespace_changes: {}",
+ report.isolated_whitespace_changes
+ );
+ println!(
+ "editable_region_correct: {}",
+ report.editable_region_correct
+ );
+ println!();
+
+ println!("Final text checks");
+ println!("-----------------");
+ println!(
+ "expected_braces_disbalance: {}",
+ report.expected_braces_disbalance
+ );
+ println!(
+ "actual_braces_disbalance: {}",
+ report.actual_braces_disbalance
+ );
+ println!();
+
+ println!("Kept-rate breakdown");
+ println!("-------------------");
+ println!(
+ "candidate_new_chars: {}",
+ report.kept_rate.candidate_new_chars
+ );
+ println!(
+ "reference_new_chars: {}",
+ report.kept_rate.reference_new_chars
+ );
+ println!(
+ "candidate_deleted_chars: {}",
+ report.kept_rate.candidate_deleted_chars
+ );
+ println!(
+ "reference_deleted_chars: {}",
+ report.kept_rate.reference_deleted_chars
+ );
+ println!("kept_chars: {}", report.kept_rate.kept_chars);
+ println!(
+ "correctly_deleted_chars: {}",
+ report.kept_rate.correctly_deleted_chars
+ );
+ println!("discarded_chars: {}", report.kept_rate.discarded_chars);
+ println!("context_chars: {}", report.kept_rate.context_chars);
+ println!();
+
+ print_kept_rate_explanation(&report.base, &report.actual, &report.expected);
+}
+
+fn print_kept_rate_explanation(base: &str, actual: &str, expected: &str) {
+ println!("Kept-rate explanation");
+ println!("---------------------");
+ println!("Legend: context = default, kept = green background, discarded = red background");
+ println!();
+
+ let annotated = annotate_kept_rate_tokens(base, actual, expected);
+ println!("Actual final text with token annotations:");
+ println!("{}", render_annotated_tokens(&annotated));
+ println!();
+}
+
+fn render_annotated_tokens(tokens: &[edit_prediction_metrics::AnnotatedToken]) -> String {
+ const RESET: &str = "\x1b[0m";
+ const KEPT_STYLE: &str = "\x1b[30;42m";
+ const DISCARDED_STYLE: &str = "\x1b[30;41m";
+
+ let mut rendered = String::new();
+ for token in tokens {
+ let style = match token.annotation {
+ TokenAnnotation::Context => "",
+ TokenAnnotation::Kept => KEPT_STYLE,
+ TokenAnnotation::Discarded => DISCARDED_STYLE,
+ };
+
+ if style.is_empty() {
+ rendered.push_str(&visualize_whitespace(&token.token));
+ } else {
+ rendered.push_str(style);
+ rendered.push_str(&visualize_whitespace(&token.token));
+ rendered.push_str(RESET);
+ }
+ }
+ rendered
+}
+
+fn visualize_whitespace(token: &str) -> String {
+ let mut rendered = String::new();
+ for ch in token.chars() {
+ match ch {
+            ' ' => rendered.push('·'),
+            '\t' => rendered.push('⇥'),
+            '\n' => rendered.push_str("↵\n"),
+ _ => rendered.push(ch),
+ }
+ }
+ rendered
+}
+
+#[derive(Debug, Deserialize)]
+struct JsonExample {
+ prompt_inputs: PromptInputs,
+ expected_patches: Vec<String>,
+ predictions: Vec<Prediction>,
+}
+
+#[derive(Debug, Deserialize)]
+struct PromptInputs {
+ cursor_excerpt: String,
+ excerpt_start_row: u32,
+}
+
+#[derive(Debug, Deserialize)]
+struct Prediction {
+ actual_patch: String,
+}
+
+#[derive(Debug, Clone)]
+struct ParsedHunk {
+ old_start: u32,
+ lines: Vec<HunkLine>,
+}
+
+#[derive(Debug, Clone)]
+enum HunkLine {
+ Context(String),
+ Addition(String),
+ Deletion(String),
+}
+
+fn apply_patch_to_excerpt(
+ base: &str,
+ patch: &str,
+ excerpt_start_row: u32,
+) -> Result<String, String> {
+ let hunks = parse_diff_hunks(patch);
+
+ let result = try_apply_hunks(base, &hunks, excerpt_start_row);
+
+ // Predicted patches may use excerpt-relative line numbers instead of
+ // file-global ones. When all hunks fall outside the excerpt window the
+ // result is identical to the base text. Retry with a zero offset so the
+ // line numbers are interpreted relative to the excerpt.
+ if excerpt_start_row > 0 && !hunks.is_empty() {
+ let should_retry = match &result {
+ Ok(text) => text == base,
+ Err(_) => true,
+ };
+
+ if should_retry {
+ let fallback = try_apply_hunks(base, &hunks, 0);
+ if matches!(&fallback, Ok(text) if text != base) {
+ return fallback;
+ }
+ }
+ }
+
+ result
+}
+
+fn try_apply_hunks(
+ base: &str,
+ hunks: &[ParsedHunk],
+ excerpt_start_row: u32,
+) -> Result<String, String> {
+ let base_has_trailing_newline = base.ends_with('\n');
+ let mut lines = split_preserving_final_empty_line(base);
+ let original_line_count = lines.len() as u32;
+
+ let excerpt_end_row = excerpt_start_row + original_line_count;
+ let mut line_delta: i64 = 0;
+
+ for hunk in hunks {
+ let filtered = match filter_hunk_to_excerpt(hunk, excerpt_start_row, excerpt_end_row) {
+ Some(filtered) => filtered,
+ None => continue,
+ };
+
+ let local_start = filtered.old_start.saturating_sub(excerpt_start_row) as i64 + line_delta;
+ if local_start < 0 {
+ return Err(format!(
+ "patch application moved before excerpt start at source row {}",
+ filtered.old_start
+ ));
+ }
+ let local_start = local_start as usize;
+
+ if local_start > lines.len() {
+ return Err(format!(
+ "patch application starts past excerpt end at local line {}",
+ local_start + 1
+ ));
+ }
+
+ let old_len = filtered
+ .lines
+ .iter()
+ .filter(|line| !matches!(line, HunkLine::Addition(_)))
+ .count();
+ let new_len = filtered
+ .lines
+ .iter()
+ .filter(|line| !matches!(line, HunkLine::Deletion(_)))
+ .count();
+
+ let old_segment: Vec<&str> = filtered
+ .lines
+ .iter()
+ .filter_map(|line| match line {
+ HunkLine::Context(text) | HunkLine::Deletion(text) => Some(text.as_str()),
+ HunkLine::Addition(_) => None,
+ })
+ .collect();
+
+ let new_segment: Vec<String> = filtered
+ .lines
+ .iter()
+ .filter_map(|line| match line {
+ HunkLine::Context(text) | HunkLine::Addition(text) => Some(text.clone()),
+ HunkLine::Deletion(_) => None,
+ })
+ .collect();
+
+ if local_start + old_len > lines.len() {
+ return Err(format!(
+ "patch application exceeds excerpt bounds near source row {}",
+ filtered.old_start
+ ));
+ }
+
+ let current_segment: Vec<&str> = lines[local_start..local_start + old_len]
+ .iter()
+ .map(String::as_str)
+ .collect();
+
+ if current_segment != old_segment {
+ let mut details = String::new();
+ let _ = write!(
+ details,
+ "patch context mismatch near source row {}: expected {:?}, found {:?}",
+ filtered.old_start, old_segment, current_segment
+ );
+ return Err(details);
+ }
+
+ lines.splice(local_start..local_start + old_len, new_segment);
+ line_delta += new_len as i64 - old_len as i64;
+ }
+
+ Ok(join_lines(&lines, base_has_trailing_newline))
+}
+
+fn split_preserving_final_empty_line(text: &str) -> Vec<String> {
+ let mut lines: Vec<String> = text.lines().map(ToString::to_string).collect();
+ if text.ends_with('\n') {
+ if lines.last().is_some_and(|line| !line.is_empty()) || lines.is_empty() {
+ lines.push(String::new());
+ }
+ }
+ lines
+}
+
+fn join_lines(lines: &[String], had_trailing_newline: bool) -> String {
+ if lines.is_empty() {
+ return String::new();
+ }
+
+ let mut joined = lines.join("\n");
+ if had_trailing_newline && !joined.ends_with('\n') {
+ joined.push('\n');
+ }
+ if !had_trailing_newline && joined.ends_with('\n') {
+ joined.pop();
+ }
+ joined
+}
+
+fn filter_hunk_to_excerpt(
+ hunk: &ParsedHunk,
+ excerpt_start_row: u32,
+ excerpt_end_row: u32,
+) -> Option<ParsedHunk> {
+ let mut filtered_lines = Vec::new();
+ let mut current_old_row = hunk.old_start.saturating_sub(1);
+ let mut filtered_old_start = None;
+ let mut has_overlap = false;
+
+ for line in &hunk.lines {
+ match line {
+ HunkLine::Context(text) => {
+ let in_excerpt =
+ current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
+ if in_excerpt {
+ filtered_old_start.get_or_insert(current_old_row);
+ filtered_lines.push(HunkLine::Context(text.clone()));
+ has_overlap = true;
+ }
+ current_old_row += 1;
+ }
+ HunkLine::Deletion(text) => {
+ let in_excerpt =
+ current_old_row >= excerpt_start_row && current_old_row < excerpt_end_row;
+ if in_excerpt {
+ filtered_old_start.get_or_insert(current_old_row);
+ filtered_lines.push(HunkLine::Deletion(text.clone()));
+ has_overlap = true;
+ }
+ current_old_row += 1;
+ }
+ HunkLine::Addition(text) => {
+ let insertion_in_excerpt =
+ current_old_row >= excerpt_start_row && current_old_row <= excerpt_end_row;
+ if insertion_in_excerpt {
+ filtered_old_start.get_or_insert(current_old_row);
+ filtered_lines.push(HunkLine::Addition(text.clone()));
+ has_overlap = true;
+ }
+ }
+ }
+ }
+
+ if !has_overlap {
+ return None;
+ }
+
+ Some(ParsedHunk {
+ old_start: filtered_old_start.unwrap_or(excerpt_start_row),
+ lines: filtered_lines,
+ })
+}
+
+fn parse_diff_hunks(diff: &str) -> Vec<ParsedHunk> {
+ let mut hunks = Vec::new();
+ let mut current_hunk: Option<ParsedHunk> = None;
+
+ for line in diff.lines() {
+ if let Some((old_start, old_count, _new_start, _new_count)) = parse_hunk_header(line) {
+ if let Some(hunk) = current_hunk.take() {
+ hunks.push(hunk);
+ }
+ let _ = old_count;
+ current_hunk = Some(ParsedHunk {
+ old_start,
+ lines: Vec::new(),
+ });
+ continue;
+ }
+
+ let Some(hunk) = current_hunk.as_mut() else {
+ continue;
+ };
+
+ if let Some(text) = line.strip_prefix('+') {
+ if !line.starts_with("+++") {
+ hunk.lines.push(HunkLine::Addition(text.to_string()));
+ }
+ } else if let Some(text) = line.strip_prefix('-') {
+ if !line.starts_with("---") {
+ hunk.lines.push(HunkLine::Deletion(text.to_string()));
+ }
+ } else if let Some(text) = line.strip_prefix(' ') {
+ hunk.lines.push(HunkLine::Context(text.to_string()));
+ } else if line.is_empty() {
+ hunk.lines.push(HunkLine::Context(String::new()));
+ }
+ }
+
+ if let Some(hunk) = current_hunk {
+ hunks.push(hunk);
+ }
+
+ hunks
+}
+
+fn parse_hunk_header(line: &str) -> Option<(u32, u32, u32, u32)> {
+ let line = line.strip_prefix("@@ -")?;
+ let (old_part, rest) = line.split_once(' ')?;
+ let rest = rest.strip_prefix('+')?;
+ let (new_part, _) = rest.split_once(" @@")?;
+
+ let (old_start, old_count) = parse_hunk_range(old_part)?;
+ let (new_start, new_count) = parse_hunk_range(new_part)?;
+ Some((old_start, old_count, new_start, new_count))
+}
+
+fn parse_hunk_range(part: &str) -> Option<(u32, u32)> {
+ if let Some((start, count)) = part.split_once(',') {
+ Some((start.parse().ok()?, count.parse().ok()?))
+ } else {
+ Some((part.parse().ok()?, 1))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn applies_patch_in_file_mode() {
+ let base = "fn main() {\n println!(\"hello\");\n}\n";
+ let patch = "@@ -1,3 +1,3 @@\n fn main() {\n- println!(\"hello\");\n+ println!(\"world\");\n }\n";
+
+ let actual = apply_patch_to_excerpt(base, patch, 0).unwrap();
+ assert_eq!(actual, "fn main() {\n println!(\"world\");\n}\n");
+ }
+
+ #[test]
+ fn applies_patch_in_json_excerpt_mode() {
+ let base = "b\nc\nd\n";
+ let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";
+
+ let actual = apply_patch_to_excerpt(base, patch, 1).unwrap();
+ assert_eq!(actual, "x\ny\nd\n");
+ }
+
+ #[test]
+ fn applies_patch_with_excerpt_relative_line_numbers() {
+ let base = "a\nb\nc\nd\n";
+ // Patch uses excerpt-relative line numbers (line 2 of excerpt)
+ // even though the excerpt starts at file row 100.
+ let patch = "@@ -2,2 +2,2 @@\n-b\n-c\n+x\n+y\n";
+
+ let actual = apply_patch_to_excerpt(base, patch, 100).unwrap();
+ assert_eq!(actual, "a\nx\ny\nd\n");
+ }
+
+ #[test]
+ fn prefers_file_global_line_numbers_over_excerpt_relative() {
+ let base = "a\nb\nc\n";
+ // Patch uses file-global line numbers: excerpt starts at row 5,
+ // hunk targets line 6 (1-based) = row 5 (0-based) = first line.
+ let patch = "@@ -6,2 +6,2 @@\n-a\n-b\n+x\n+y\n";
+
+ let actual = apply_patch_to_excerpt(base, patch, 5).unwrap();
+ assert_eq!(actual, "x\ny\nc\n");
+ }
+}
@@ -687,6 +687,35 @@ fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<DiffOp> {
.collect()
}
+/// Reconstruct old and new text from a unified diff.
+///
+/// Context and deletion lines form the old text; context and addition lines
+/// form the new text, joined with `\n` (no trailing newline). Returns `(old_text, new_text)`.
+pub fn reconstruct_texts_from_diff(patch_str: &str) -> (String, String) {
+ let patch = Patch::parse_unified_diff(patch_str);
+ let mut old_lines: Vec<&str> = Vec::new();
+ let mut new_lines: Vec<&str> = Vec::new();
+
+ for hunk in &patch.hunks {
+ for line in &hunk.lines {
+ match line {
+ PatchLine::Context(content) => {
+ old_lines.push(content);
+ new_lines.push(content);
+ }
+ PatchLine::Deletion(content) => {
+ old_lines.push(content);
+ }
+ PatchLine::Addition(content) => {
+ new_lines.push(content);
+ }
+ PatchLine::Garbage(_) => {}
+ }
+ }
+ }
+
+ (old_lines.join("\n"), new_lines.join("\n"))
+}
#[derive(Debug, Default, Clone)]
struct Patch {
hunks: Vec<Hunk>,
@@ -1,33 +1,158 @@
-fn char_class(character: char) -> u8 {
- if character.is_alphanumeric() || character == '_' {
- 0
+use std::{iter::Peekable, str::CharIndices};
+
+#[derive(Clone, Copy, PartialEq, Eq)]
+enum CharClass {
+ Identifier,
+ Newline,
+ Whitespace,
+ Punctuation,
+}
+
+const MULTI_CHAR_PUNCTUATION: &[&str] = &[
+ ">>>=", "<<=", ">>=", "...", "..=", "??=", "**=", ">>>", "::", "->", "=>", "==", "!=", "<=",
+ ">=", "&&", "||", "<<", ">>", "..", "+=", "-=", "*=", "/=", "%=", "&=", "|=", "^=", "++", "--",
+ "**", "??", "?.", ":=", "<-", "//", "/*", "*/",
+];
+
+fn char_class(character: char) -> CharClass {
+ if character == '\n' || character == '\r' {
+ CharClass::Newline
} else if character.is_whitespace() {
- 1
+ CharClass::Whitespace
+ } else if character.is_alphanumeric() || character == '_' {
+ CharClass::Identifier
} else {
- 2
+ CharClass::Punctuation
}
}
+fn is_identifier_boundary(previous: char, current: char, next: Option<char>) -> bool {
+ (current.is_uppercase() && (previous.is_lowercase() || previous.is_numeric()))
+ || (current.is_uppercase()
+ && previous.is_uppercase()
+ && next.is_some_and(|next| next.is_lowercase()))
+}
+
+fn push_identifier_tokens<'a>(identifier: &'a str, tokens: &mut Vec<&'a str>) {
+ let characters: Vec<(usize, char)> = identifier.char_indices().collect();
+ let mut segment_start = 0;
+ let mut index = 0;
+
+ while index < characters.len() {
+ let (byte_index, character) = characters[index];
+
+ if character == '_' {
+ if segment_start < byte_index {
+ tokens.push(&identifier[segment_start..byte_index]);
+ }
+
+ let mut underscore_end = byte_index + character.len_utf8();
+ index += 1;
+
+ while index < characters.len() && characters[index].1 == '_' {
+ underscore_end = characters[index].0 + characters[index].1.len_utf8();
+ index += 1;
+ }
+
+ tokens.push(&identifier[byte_index..underscore_end]);
+ segment_start = underscore_end;
+ continue;
+ }
+
+ if byte_index > segment_start {
+ let previous = characters[index - 1].1;
+ let next = characters.get(index + 1).map(|(_, character)| *character);
+
+ if is_identifier_boundary(previous, character, next) {
+ tokens.push(&identifier[segment_start..byte_index]);
+ segment_start = byte_index;
+ }
+ }
+
+ index += 1;
+ }
+
+ if segment_start < identifier.len() {
+ tokens.push(&identifier[segment_start..]);
+ }
+}
+
+fn push_punctuation_token<'a>(
+ text: &'a str,
+ start: usize,
+ character: char,
+ characters: &mut Peekable<CharIndices<'a>>,
+ tokens: &mut Vec<&'a str>,
+) {
+ let remaining = &text[start..];
+
+ for punctuation in MULTI_CHAR_PUNCTUATION {
+ if remaining.starts_with(punctuation) {
+ for _ in punctuation.chars().skip(1) {
+ characters.next();
+ }
+
+ tokens.push(&remaining[..punctuation.len()]);
+ return;
+ }
+ }
+
+ let end = start + character.len_utf8();
+ tokens.push(&text[start..end]);
+}
+
pub(crate) fn tokenize(text: &str) -> Vec<&str> {
let mut tokens = Vec::new();
let mut characters = text.char_indices().peekable();
while let Some((start, character)) = characters.next() {
- let class = char_class(character);
- if class == 2 {
- tokens.push(&text[start..start + character.len_utf8()]);
- continue;
- }
+ match char_class(character) {
+ CharClass::Identifier => {
+ let mut end = start + character.len_utf8();
+
+ while let Some(&(next_start, next_character)) = characters.peek() {
+ if char_class(next_character) != CharClass::Identifier {
+ break;
+ }
+
+ end = next_start + next_character.len_utf8();
+ characters.next();
+ }
+
+ push_identifier_tokens(&text[start..end], &mut tokens);
+ }
+ CharClass::Newline => {
+ let mut end = start + character.len_utf8();
+
+ while let Some(&(next_start, next_character)) = characters.peek() {
+ if char_class(next_character) != CharClass::Newline {
+ break;
+ }
+
+ end = next_start + next_character.len_utf8();
+ characters.next();
+ }
- let mut end = start + character.len_utf8();
- while let Some(&(_, next_character)) = characters.peek() {
- if char_class(next_character) != class {
- break;
+ tokens.push(&text[start..end]);
+ }
+ CharClass::Whitespace => {
+ let mut end = start + character.len_utf8();
+
+ while let Some(&(next_start, next_character)) = characters.peek() {
+ if char_class(next_character) != CharClass::Whitespace {
+ break;
+ }
+
+ end = next_start + next_character.len_utf8();
+ characters.next();
+ }
+
+ tokens.push(&text[start..end]);
+ }
+ CharClass::Punctuation => {
+ push_punctuation_token(text, start, character, &mut characters, &mut tokens);
}
- end += next_character.len_utf8();
- characters.next();
}
- tokens.push(&text[start..end]);
}
tokens
@@ -38,17 +163,58 @@ mod tests {
use super::tokenize;
#[test]
- fn tokenizes_code_like_text() {
+ fn tokenizes_code() {
assert_eq!(tokenize("hello world"), vec!["hello", " ", "world"]);
assert_eq!(
tokenize("foo_bar123 + baz"),
- vec!["foo_bar123", " ", "+", " ", "baz"]
+ vec!["foo", "_", "bar123", " ", "+", " ", "baz"]
);
assert_eq!(
tokenize("print(\"hello\")"),
vec!["print", "(", "\"", "hello", "\"", ")"]
);
- assert_eq!(tokenize("hello_world"), vec!["hello_world"]);
+ assert_eq!(tokenize("hello_world"), vec!["hello", "_", "world"]);
assert_eq!(tokenize("fn();"), vec!["fn", "(", ")", ";"]);
}
+
+ #[test]
+ fn tokenizes_identifier_case_styles() {
+ assert_eq!(
+ tokenize("camelCase PascalCase snake_case"),
+ vec![
+ "camel", "Case", " ", "Pascal", "Case", " ", "snake", "_", "case"
+ ]
+ );
+ assert_eq!(
+ tokenize("myHTTPServer __private_value foo__bar"),
+ vec![
+ "my", "HTTP", "Server", " ", "__", "private", "_", "value", " ", "foo", "__", "bar"
+ ]
+ );
+ assert_eq!(
+ tokenize("XMLHttpRequest Version2Update"),
+ vec!["XML", "Http", "Request", " ", "Version2", "Update"]
+ );
+ }
+
+ #[test]
+ fn tokenizes_grouped_punctuation() {
+ assert_eq!(
+ tokenize("a::b -> c != d ..= e"),
+ vec![
+ "a", "::", "b", " ", "->", " ", "c", " ", "!=", " ", "d", " ", "..=", " ", "e"
+ ]
+ );
+ assert_eq!(
+ tokenize("foo?.bar ?? baz"),
+ vec!["foo", "?.", "bar", " ", "??", " ", "baz"]
+ );
+ }
+
+ #[test]
+ fn tokenize_whitespace_runs() {
+ assert_eq!(tokenize(" "), vec![" "]);
+ assert_eq!(tokenize(" \n foo"), vec![" ", "\n", " ", "foo"]);
+ assert_eq!(tokenize("\r\n\nfoo"), vec!["\r\n\n", "foo"]);
+ }
}
@@ -63,6 +63,7 @@ extend-exclude = [
"crates/gpui_macos/src/dispatcher.rs",
# Tests contain partially incomplete words (by design)
"crates/edit_prediction_cli/src/split_commit.rs",
+ "crates/edit_prediction_metrics/src/kept_rate.rs",
# Eval examples contain intentionally partial words (e.g. "secur" for "secure")
"crates/edit_prediction_cli/evals/",
# Tests contain `baหr` that cause `"ba" should be "by" or "be".`-like false-positives