@@ -1,8 +1,6 @@
-use anyhow::{Context as _, Result, anyhow};
-use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
-use std::ops::Range;
-use std::path::Path;
-use std::sync::Arc;
+use anyhow::{Context as _, Result};
+use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
+use std::{cmp, ops::Range, path::Path, sync::Arc};
pub async fn parse_xml_edits<'a>(
input: &'a str,
@@ -40,128 +38,76 @@ async fn parse_xml_edits_inner<'a>(
while let Some(old_text_tag) = parse_tag(&mut input, "old_text")? {
let new_text_tag =
parse_tag(&mut input, "new_text")?.context("no new_text tag following old_text")?;
- edits.extend(resolve_new_text_old_text_in_buffer(
- new_text_tag.body,
- old_text_tag.body,
- buffer,
- context_ranges,
- )?);
+ let match_range = fuzzy_match_in_ranges(old_text_tag.body, buffer, context_ranges)?;
+ let old_text = buffer
+ .text_for_range(match_range.clone())
+ .collect::<String>();
+ let edits_within_hunk = language::text_diff(&old_text, &new_text_tag.body);
+ edits.extend(
+ edits_within_hunk
+ .into_iter()
+ .map(move |(inner_range, inner_text)| {
+ (
+ buffer.anchor_after(match_range.start + inner_range.start)
+ ..buffer.anchor_before(match_range.start + inner_range.end),
+ inner_text,
+ )
+ }),
+ );
}
Ok((buffer, edits))
}
-fn resolve_new_text_old_text_in_buffer(
- new_text: &str,
+fn fuzzy_match_in_ranges(
old_text: &str,
- buffer: &TextBufferSnapshot,
- ranges: &[Range<Anchor>],
-) -> Result<impl Iterator<Item = (Range<Anchor>, Arc<str>)>, anyhow::Error> {
- let context_offset = if old_text.is_empty() {
- Ok(0)
- } else {
- let mut offset = None;
- for range in ranges {
- let range = range.to_offset(buffer);
- let text = buffer.text_for_range(range.clone()).collect::<String>();
- for (match_offset, _) in text.match_indices(old_text) {
- if let Some(offset) = offset {
- let offset_match_point = buffer.offset_to_point(offset);
- let second_match_point = buffer.offset_to_point(range.start + match_offset);
- anyhow::bail!(
- "old_text is not unique enough:\n{}\nFound at {:?} and {:?}",
- old_text,
- offset_match_point,
- second_match_point
- );
+ buffer: &BufferSnapshot,
+ context_ranges: &[Range<Anchor>],
+) -> Result<Range<usize>> {
+ let mut state = FuzzyMatcher::new(buffer, old_text);
+ let mut best_match = None;
+ let mut tie_match_range = None;
+
+ for range in context_ranges {
+ let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
+ match (best_match_cost, state.match_range(range.to_offset(buffer))) {
+ (Some(lowest_cost), Some((new_cost, new_range))) => {
+ if new_cost == lowest_cost {
+ tie_match_range = Some(new_range);
+ } else if new_cost < lowest_cost {
+ tie_match_range.take();
+ best_match = Some((new_cost, new_range));
}
- offset = Some(range.start + match_offset);
}
- }
- offset.ok_or_else(|| {
- #[cfg(any(debug_assertions, feature = "eval-support"))]
- if let Some(closest_match) = closest_old_text_match(buffer, old_text) {
- log::info!(
- "Closest `old_text` match: {}",
- pretty_assertions::StrComparison::new(old_text, &closest_match)
- )
+ (None, Some(new_match)) => {
+ best_match = Some(new_match);
}
- anyhow!("Failed to match old_text:\n{}", old_text)
- })
- }?;
-
- let edits_within_hunk = language::text_diff(&old_text, &new_text);
- Ok(edits_within_hunk
- .into_iter()
- .map(move |(inner_range, inner_text)| {
- (
- buffer.anchor_after(context_offset + inner_range.start)
- ..buffer.anchor_before(context_offset + inner_range.end),
- inner_text,
- )
- }))
-}
-
-#[cfg(any(debug_assertions, feature = "eval-support"))]
-fn closest_old_text_match(buffer: &TextBufferSnapshot, old_text: &str) -> Option<String> {
- let buffer_text = buffer.text();
- let len = old_text.len();
-
- if len == 0 || buffer_text.len() < len {
- return None;
+ (None, None) | (Some(_), None) => {}
+ };
}
- let mut min_score = usize::MAX;
- let mut min_start = 0;
-
- let old_text_bytes = old_text.as_bytes();
- let old_alpha_count = old_text_bytes
- .iter()
- .filter(|&&b| b.is_ascii_alphanumeric())
- .count();
-
- let old_line_count = old_text.lines().count();
-
- let mut cursor = 0;
-
- while cursor + len <= buffer_text.len() {
- let candidate = &buffer_text[cursor..cursor + len];
- let candidate_bytes = candidate.as_bytes();
-
- if usize::abs_diff(candidate.lines().count(), old_line_count) > 4 {
- cursor += 1;
- continue;
- }
-
- let candidate_alpha_count = candidate_bytes
- .iter()
- .filter(|&&b| b.is_ascii_alphanumeric())
- .count();
-
- // If alphanumeric character count differs by more than 30%, skip
- if usize::abs_diff(old_alpha_count, candidate_alpha_count) * 10 > old_alpha_count * 3 {
- cursor += 1;
- continue;
- }
-
- let score = strsim::levenshtein(candidate, old_text);
- if score < min_score {
- min_score = score;
- min_start = cursor;
-
- if min_score <= len / 10 {
- break;
- }
+ if let Some((_, best_match_range)) = best_match {
+ if let Some(tie_match_range) = tie_match_range {
+ anyhow::bail!(
+ "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
+ best_match_range.clone(),
+ buffer.text_for_range(best_match_range).collect::<String>(),
+ tie_match_range.clone(),
+ buffer.text_for_range(tie_match_range).collect::<String>()
+ );
}
-
- cursor += 1;
+ return Ok(best_match_range);
}
- if min_score != usize::MAX {
- Some(buffer_text[min_start..min_start + len].to_string())
- } else {
- None
- }
+ anyhow::bail!(
+ "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
+ old_text,
+ context_ranges
+ .iter()
+ .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
+ .collect::<Vec<String>>()
+ .join("```\n```")
+ );
}
struct ParsedTag<'a> {
@@ -187,10 +133,218 @@ fn parse_tag<'a>(input: &mut &'a str, tag: &str) -> Result<Option<ParsedTag<'a>>
.with_context(|| format!("no `{close_tag}` tag"))?;
let body = &input[closing_bracket_ix + '>'.len_utf8()..end_ix];
let body = body.strip_prefix('\n').unwrap_or(body);
+ let body = body.strip_suffix('\n').unwrap_or(body);
*input = &input[end_ix + close_tag.len()..];
Ok(Some(ParsedTag { attributes, body }))
}
+const REPLACEMENT_COST: u32 = 1;
+const INSERTION_COST: u32 = 3;
+const DELETION_COST: u32 = 10;
+
+/// A fuzzy matcher that can process text chunks incrementally
+/// and return the best match found so far at each step.
+struct FuzzyMatcher<'a> {
+ snapshot: &'a BufferSnapshot,
+ query_lines: Vec<&'a str>,
+ matrix: SearchMatrix,
+}
+
+impl<'a> FuzzyMatcher<'a> {
+ fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
+ let query_lines = old_text.lines().collect();
+ Self {
+ snapshot,
+ query_lines,
+ matrix: SearchMatrix::new(0),
+ }
+ }
+
+ fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
+ let point_range = range.to_point(&self.snapshot);
+ let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
+
+ self.matrix
+ .reset(self.query_lines.len() + 1, buffer_line_count + 1);
+ let query_line_count = self.query_lines.len();
+
+ for row in 0..query_line_count {
+ let query_line = self.query_lines[row].trim();
+ let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
+
+ self.matrix.set(
+ row + 1,
+ 0,
+ SearchState::new(leading_deletion_cost, SearchDirection::Up),
+ );
+
+ let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
+
+ let mut col = 0;
+ while let Some(buffer_line) = buffer_lines.next() {
+ let buffer_line = buffer_line.trim();
+ let up = SearchState::new(
+ self.matrix
+ .get(row, col + 1)
+ .cost
+ .saturating_add(DELETION_COST),
+ SearchDirection::Up,
+ );
+ let left = SearchState::new(
+ self.matrix
+ .get(row + 1, col)
+ .cost
+ .saturating_add(INSERTION_COST),
+ SearchDirection::Left,
+ );
+ let diagonal = SearchState::new(
+ if query_line == buffer_line {
+ self.matrix.get(row, col).cost
+ } else if fuzzy_eq(query_line, buffer_line) {
+ self.matrix.get(row, col).cost + REPLACEMENT_COST
+ } else {
+ self.matrix
+ .get(row, col)
+ .cost
+ .saturating_add(DELETION_COST + INSERTION_COST)
+ },
+ SearchDirection::Diagonal,
+ );
+ self.matrix
+ .set(row + 1, col + 1, up.min(left).min(diagonal));
+ col += 1;
+ }
+ }
+
+ // Find all matches with the best cost
+ let mut best_cost = u32::MAX;
+ let mut matches_with_best_cost = Vec::new();
+
+ for col in 1..=buffer_line_count {
+ let cost = self.matrix.get(query_line_count, col).cost;
+ if cost < best_cost {
+ best_cost = cost;
+ matches_with_best_cost.clear();
+ matches_with_best_cost.push(col as u32);
+ } else if cost == best_cost {
+ matches_with_best_cost.push(col as u32);
+ }
+ }
+
+ // Find ranges for the matches
+ for &match_end_col in &matches_with_best_cost {
+ let mut matched_lines = 0;
+ let mut query_row = query_line_count;
+ let mut match_start_col = match_end_col;
+ while query_row > 0 && match_start_col > 0 {
+ let current = self.matrix.get(query_row, match_start_col as usize);
+ match current.direction {
+ SearchDirection::Diagonal => {
+ query_row -= 1;
+ match_start_col -= 1;
+ matched_lines += 1;
+ }
+ SearchDirection::Up => {
+ query_row -= 1;
+ }
+ SearchDirection::Left => {
+ match_start_col -= 1;
+ }
+ }
+ }
+
+ let buffer_row_start = match_start_col + point_range.start.row;
+ let buffer_row_end = match_end_col + point_range.start.row;
+
+ let matched_buffer_row_count = buffer_row_end - buffer_row_start;
+ let matched_ratio = matched_lines as f32
+ / (matched_buffer_row_count as f32).max(query_line_count as f32);
+ if matched_ratio >= 0.8 {
+ let buffer_start_ix = self
+ .snapshot
+ .point_to_offset(Point::new(buffer_row_start, 0));
+ let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
+ buffer_row_end - 1,
+ self.snapshot.line_len(buffer_row_end - 1),
+ ));
+ return Some((best_cost, buffer_start_ix..buffer_end_ix));
+ }
+ }
+
+ None
+ }
+}
+
+fn fuzzy_eq(left: &str, right: &str) -> bool {
+ const THRESHOLD: f64 = 0.8;
+
+ let min_levenshtein = left.len().abs_diff(right.len());
+ let min_normalized_levenshtein =
+ 1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
+ if min_normalized_levenshtein < THRESHOLD {
+ return false;
+ }
+
+ strsim::normalized_levenshtein(left, right) >= THRESHOLD
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+enum SearchDirection {
+ Up,
+ Left,
+ Diagonal,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
+struct SearchState {
+ cost: u32,
+ direction: SearchDirection,
+}
+
+impl SearchState {
+ fn new(cost: u32, direction: SearchDirection) -> Self {
+ Self { cost, direction }
+ }
+}
+
+struct SearchMatrix {
+ cols: usize,
+ rows: usize,
+ data: Vec<SearchState>,
+}
+
+impl SearchMatrix {
+ fn new(cols: usize) -> Self {
+ SearchMatrix {
+ cols,
+ rows: 0,
+ data: Vec::new(),
+ }
+ }
+
+ fn reset(&mut self, rows: usize, cols: usize) {
+ self.rows = rows;
+ self.cols = cols;
+ self.data
+ .fill(SearchState::new(0, SearchDirection::Diagonal));
+ self.data.resize(
+ self.rows * self.cols,
+ SearchState::new(0, SearchDirection::Diagonal),
+ );
+ }
+
+ fn get(&self, row: usize, col: usize) -> SearchState {
+ debug_assert!(row < self.rows);
+ debug_assert!(col < self.cols);
+ self.data[row * self.cols + col]
+ }
+
+ fn set(&mut self, row: usize, col: usize, state: SearchState) {
+ debug_assert!(row < self.rows && col < self.cols);
+ self.data[row * self.cols + col] = state;
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -212,7 +366,7 @@ mod tests {
"# };
let parsed = parse_tag(&mut input, "tag").unwrap().unwrap();
assert_eq!(parsed.attributes, "attr=\"foo\"");
- assert_eq!(parsed.body, "tag value\n");
+ assert_eq!(parsed.body, "tag value");
assert_eq!(input, "\n");
}
@@ -224,7 +378,9 @@ mod tests {
one two three four
five six seven eight
nine ten eleven twelve
- "# };
+ thirteen fourteen fifteen
+ sixteen seventeen eighteen
+ "#};
fs.insert_tree(
path!("/root"),
@@ -246,16 +402,17 @@ mod tests {
let edits = indoc! {r#"
<edits path="root/file1">
<old_text>
- five six seven eight
+ nine ten eleven twelve
</old_text>
<new_text>
- five SIX seven eight!
+ nine TEN eleven twelve!
</new_text>
</edits>
"#};
+ let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
let (buffer, edits) = parse_xml_edits(edits, |_path| {
- Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
+ Some((&buffer_snapshot, included_ranges.as_slice()))
})
.await
.unwrap();
@@ -267,8 +424,8 @@ mod tests {
assert_eq!(
edits,
&[
- (Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
- (Point::new(1, 20)..Point::new(1, 20), "!".into())
+ (Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
+ (Point::new(2, 22)..Point::new(2, 22), "!".into())
]
);
}
@@ -84,7 +84,7 @@ pub async fn run_evaluate(
{
write_aggregated_scores(&mut output_file, &all_results).log_err();
};
- print_run_data_dir(args.repetitions == 1);
+ print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
fn write_aggregated_scores(
@@ -103,8 +103,7 @@ fn write_aggregated_scores(
}
failed_count += 1;
- let err = err
- .to_string()
+ let err = format!("{err:?}")
.replace("<edits", "```xml\n<edits")
.replace("</edits>", "</edits>\n```");
writeln!(
@@ -173,6 +172,7 @@ pub async fn run_evaluate_one(
&predict_result,
&evaluation_result,
&mut std::io::stdout(),
+ std::io::stdout().is_terminal(),
)?;
}
@@ -184,6 +184,7 @@ pub async fn run_evaluate_one(
&predict_result,
&evaluation_result,
&mut results_file,
+ false,
)
.log_err();
}
@@ -196,16 +197,25 @@ fn write_eval_result(
predictions: &PredictionDetails,
evaluation_result: &EvaluationResult,
out: &mut impl Write,
+ use_color: bool,
) -> Result<()> {
writeln!(
out,
"## Expected edit prediction:\n\n```diff\n{}\n```\n",
- compare_diffs(&example.example.expected_patch, &predictions.diff)
+ compare_diffs(
+ &example.example.expected_patch,
+ &predictions.diff,
+ use_color
+ )
)?;
writeln!(
out,
"## Actual edit prediction:\n\n```diff\n{}\n```\n",
- compare_diffs(&predictions.diff, &example.example.expected_patch)
+ compare_diffs(
+ &predictions.diff,
+ &example.example.expected_patch,
+ use_color
+ )
)?;
writeln!(out, "{:#}", evaluation_result)?;
@@ -434,8 +444,7 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
/// Return annotated `patch_a` so that:
/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
/// Additions and deletions that are present in `patch_b` will be highlighted in green.
-pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
- let use_color = std::io::stdout().is_terminal();
+pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
let green = if use_color { "\x1b[32m✓ " } else { "" };
let red = if use_color { "\x1b[31m✗ " } else { "" };
let neutral = if use_color { " " } else { "" };