zeta2: Improve zeta old text matching (#42580)

Ben Kunkle, Max, Michael, Max Brunsfeld, and Agus created

This PR improves Zeta2's matching of `old_text`/`new_text` pairs, using
code similar to what we use in the edit agent. For now, we've duplicated
that code rather than trying to generalize it.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Michael <michael@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-authored-by: Agus <agus@zed.dev>

Change summary

crates/cloud_zeta2_prompt/src/retrieval_prompt.rs |   4 
crates/zeta2/src/assemble_excerpts.rs             |  19 
crates/zeta2/src/retrieval_search.rs              |  14 
crates/zeta2/src/xml_edits.rs                     | 397 +++++++++++-----
crates/zeta2/src/zeta2.rs                         |  27 
crates/zeta_cli/src/evaluate.rs                   |  23 
crates/zeta_cli/src/paths.rs                      |  10 
crates/zeta_cli/src/predict.rs                    |  11 
8 files changed, 331 insertions(+), 174 deletions(-)

Detailed changes

crates/cloud_zeta2_prompt/src/retrieval_prompt.rs 🔗

@@ -11,11 +11,11 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R
     let mut prompt = SEARCH_INSTRUCTIONS.to_string();
 
     if !request.events.is_empty() {
-        writeln!(&mut prompt, "## User Edits\n")?;
+        writeln!(&mut prompt, "\n## User Edits\n\n")?;
         push_events(&mut prompt, &request.events);
     }
 
-    writeln!(&mut prompt, "## Cursor context")?;
+    writeln!(&mut prompt, "## Cursor context\n")?;
     write_codeblock(
         &request.excerpt_path,
         &[Excerpt {

crates/zeta2/src/merge_excerpts.rs → crates/zeta2/src/assemble_excerpts.rs 🔗

@@ -3,27 +3,16 @@ use edit_prediction_context::Line;
 use language::{BufferSnapshot, Point};
 use std::ops::Range;
 
-pub fn merge_excerpts(
+pub fn assemble_excerpts(
     buffer: &BufferSnapshot,
-    sorted_line_ranges: impl IntoIterator<Item = Range<Line>>,
+    merged_line_ranges: impl IntoIterator<Item = Range<Line>>,
 ) -> Vec<Excerpt> {
     let mut output = Vec::new();
-    let mut merged_ranges = Vec::<Range<Line>>::new();
-
-    for line_range in sorted_line_ranges {
-        if let Some(last_line_range) = merged_ranges.last_mut()
-            && line_range.start <= last_line_range.end
-        {
-            last_line_range.end = last_line_range.end.max(line_range.end);
-            continue;
-        }
-        merged_ranges.push(line_range);
-    }
 
     let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
     let mut outline_items = outline_items.into_iter().peekable();
 
-    for range in merged_ranges {
+    for range in merged_line_ranges {
         let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0);
 
         while let Some(outline_item) = outline_items.peek() {
@@ -155,7 +144,7 @@ mod tests {
 
                 let mut output = String::new();
                 cloud_zeta2_prompt::write_excerpts(
-                    merge_excerpts(&buffer.snapshot(), ranges).iter(),
+                    assemble_excerpts(&buffer.snapshot(), ranges).iter(),
                     &insertions,
                     Line(buffer.max_point().row),
                     true,

crates/zeta2/src/retrieval_search.rs 🔗

@@ -64,7 +64,7 @@ pub async fn run_retrieval_searches(
                     })?
                     .await?;
                 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
-                let mut ranges = ranges
+                let mut ranges: Vec<_> = ranges
                     .into_iter()
                     .map(|range| {
                         snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
@@ -172,11 +172,11 @@ pub async fn run_retrieval_searches(
     .await
 }
 
-fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
+pub(crate) fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
     ranges.sort_unstable_by(|a, b| {
         a.start
             .cmp(&b.start, snapshot)
-            .then(b.end.cmp(&b.end, snapshot))
+            .then(b.end.cmp(&a.end, snapshot))
     });
 
     let mut index = 1;
@@ -187,7 +187,9 @@ fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapsho
             .is_ge()
         {
             let removed = ranges.remove(index);
-            ranges[index - 1].end = removed.end;
+            if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() {
+                ranges[index - 1].end = removed.end;
+            }
         } else {
             index += 1;
         }
@@ -416,7 +418,7 @@ fn expand_to_parent_range<T: ToPoint + ToOffset>(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::merge_excerpts::merge_excerpts;
+    use crate::assemble_excerpts::assemble_excerpts;
     use cloud_zeta2_prompt::write_codeblock;
     use edit_prediction_context::Line;
     use gpui::TestAppContext;
@@ -602,7 +604,7 @@ mod tests {
 
                 write_codeblock(
                     &buffer.file().unwrap().full_path(cx),
-                    merge_excerpts(&buffer.snapshot(), excerpts).iter(),
+                    assemble_excerpts(&buffer.snapshot(), excerpts).iter(),
                     &[],
                     Line(buffer.max_point().row),
                     false,

crates/zeta2/src/xml_edits.rs 🔗

@@ -1,8 +1,6 @@
-use anyhow::{Context as _, Result, anyhow};
-use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
-use std::ops::Range;
-use std::path::Path;
-use std::sync::Arc;
+use anyhow::{Context as _, Result};
+use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
+use std::{cmp, ops::Range, path::Path, sync::Arc};
 
 pub async fn parse_xml_edits<'a>(
     input: &'a str,
@@ -40,128 +38,76 @@ async fn parse_xml_edits_inner<'a>(
     while let Some(old_text_tag) = parse_tag(&mut input, "old_text")? {
         let new_text_tag =
             parse_tag(&mut input, "new_text")?.context("no new_text tag following old_text")?;
-        edits.extend(resolve_new_text_old_text_in_buffer(
-            new_text_tag.body,
-            old_text_tag.body,
-            buffer,
-            context_ranges,
-        )?);
+        let match_range = fuzzy_match_in_ranges(old_text_tag.body, buffer, context_ranges)?;
+        let old_text = buffer
+            .text_for_range(match_range.clone())
+            .collect::<String>();
+        let edits_within_hunk = language::text_diff(&old_text, &new_text_tag.body);
+        edits.extend(
+            edits_within_hunk
+                .into_iter()
+                .map(move |(inner_range, inner_text)| {
+                    (
+                        buffer.anchor_after(match_range.start + inner_range.start)
+                            ..buffer.anchor_before(match_range.start + inner_range.end),
+                        inner_text,
+                    )
+                }),
+        );
     }
 
     Ok((buffer, edits))
 }
 
-fn resolve_new_text_old_text_in_buffer(
-    new_text: &str,
+fn fuzzy_match_in_ranges(
     old_text: &str,
-    buffer: &TextBufferSnapshot,
-    ranges: &[Range<Anchor>],
-) -> Result<impl Iterator<Item = (Range<Anchor>, Arc<str>)>, anyhow::Error> {
-    let context_offset = if old_text.is_empty() {
-        Ok(0)
-    } else {
-        let mut offset = None;
-        for range in ranges {
-            let range = range.to_offset(buffer);
-            let text = buffer.text_for_range(range.clone()).collect::<String>();
-            for (match_offset, _) in text.match_indices(old_text) {
-                if let Some(offset) = offset {
-                    let offset_match_point = buffer.offset_to_point(offset);
-                    let second_match_point = buffer.offset_to_point(range.start + match_offset);
-                    anyhow::bail!(
-                        "old_text is not unique enough:\n{}\nFound at {:?} and {:?}",
-                        old_text,
-                        offset_match_point,
-                        second_match_point
-                    );
+    buffer: &BufferSnapshot,
+    context_ranges: &[Range<Anchor>],
+) -> Result<Range<usize>> {
+    let mut state = FuzzyMatcher::new(buffer, old_text);
+    let mut best_match = None;
+    let mut tie_match_range = None;
+
+    for range in context_ranges {
+        let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
+        match (best_match_cost, state.match_range(range.to_offset(buffer))) {
+            (Some(lowest_cost), Some((new_cost, new_range))) => {
+                if new_cost == lowest_cost {
+                    tie_match_range = Some(new_range);
+                } else if new_cost < lowest_cost {
+                    tie_match_range.take();
+                    best_match = Some((new_cost, new_range));
                 }
-                offset = Some(range.start + match_offset);
             }
-        }
-        offset.ok_or_else(|| {
-            #[cfg(any(debug_assertions, feature = "eval-support"))]
-            if let Some(closest_match) = closest_old_text_match(buffer, old_text) {
-                log::info!(
-                    "Closest `old_text` match: {}",
-                    pretty_assertions::StrComparison::new(old_text, &closest_match)
-                )
+            (None, Some(new_match)) => {
+                best_match = Some(new_match);
             }
-            anyhow!("Failed to match old_text:\n{}", old_text)
-        })
-    }?;
-
-    let edits_within_hunk = language::text_diff(&old_text, &new_text);
-    Ok(edits_within_hunk
-        .into_iter()
-        .map(move |(inner_range, inner_text)| {
-            (
-                buffer.anchor_after(context_offset + inner_range.start)
-                    ..buffer.anchor_before(context_offset + inner_range.end),
-                inner_text,
-            )
-        }))
-}
-
-#[cfg(any(debug_assertions, feature = "eval-support"))]
-fn closest_old_text_match(buffer: &TextBufferSnapshot, old_text: &str) -> Option<String> {
-    let buffer_text = buffer.text();
-    let len = old_text.len();
-
-    if len == 0 || buffer_text.len() < len {
-        return None;
+            (None, None) | (Some(_), None) => {}
+        };
     }
 
-    let mut min_score = usize::MAX;
-    let mut min_start = 0;
-
-    let old_text_bytes = old_text.as_bytes();
-    let old_alpha_count = old_text_bytes
-        .iter()
-        .filter(|&&b| b.is_ascii_alphanumeric())
-        .count();
-
-    let old_line_count = old_text.lines().count();
-
-    let mut cursor = 0;
-
-    while cursor + len <= buffer_text.len() {
-        let candidate = &buffer_text[cursor..cursor + len];
-        let candidate_bytes = candidate.as_bytes();
-
-        if usize::abs_diff(candidate.lines().count(), old_line_count) > 4 {
-            cursor += 1;
-            continue;
-        }
-
-        let candidate_alpha_count = candidate_bytes
-            .iter()
-            .filter(|&&b| b.is_ascii_alphanumeric())
-            .count();
-
-        // If alphanumeric character count differs by more than 30%, skip
-        if usize::abs_diff(old_alpha_count, candidate_alpha_count) * 10 > old_alpha_count * 3 {
-            cursor += 1;
-            continue;
-        }
-
-        let score = strsim::levenshtein(candidate, old_text);
-        if score < min_score {
-            min_score = score;
-            min_start = cursor;
-
-            if min_score <= len / 10 {
-                break;
-            }
+    if let Some((_, best_match_range)) = best_match {
+        if let Some(tie_match_range) = tie_match_range {
+            anyhow::bail!(
+                "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
+                best_match_range.clone(),
+                buffer.text_for_range(best_match_range).collect::<String>(),
+                tie_match_range.clone(),
+                buffer.text_for_range(tie_match_range).collect::<String>()
+            );
         }
-
-        cursor += 1;
+        return Ok(best_match_range);
     }
 
-    if min_score != usize::MAX {
-        Some(buffer_text[min_start..min_start + len].to_string())
-    } else {
-        None
-    }
+    anyhow::bail!(
+        "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
+        old_text,
+        context_ranges
+            .iter()
+            .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
+            .collect::<Vec<String>>()
+            .join("```\n```")
+    );
 }
 
 struct ParsedTag<'a> {
@@ -187,10 +133,218 @@ fn parse_tag<'a>(input: &mut &'a str, tag: &str) -> Result<Option<ParsedTag<'a>>
             .with_context(|| format!("no `{close_tag}` tag"))?;
     let body = &input[closing_bracket_ix + '>'.len_utf8()..end_ix];
     let body = body.strip_prefix('\n').unwrap_or(body);
+    let body = body.strip_suffix('\n').unwrap_or(body);
     *input = &input[end_ix + close_tag.len()..];
     Ok(Some(ParsedTag { attributes, body }))
 }
 
+const REPLACEMENT_COST: u32 = 1;
+const INSERTION_COST: u32 = 3;
+const DELETION_COST: u32 = 10;
+
+/// A fuzzy matcher that can process text chunks incrementally
+/// and return the best match found so far at each step.
+struct FuzzyMatcher<'a> {
+    snapshot: &'a BufferSnapshot,
+    query_lines: Vec<&'a str>,
+    matrix: SearchMatrix,
+}
+
+impl<'a> FuzzyMatcher<'a> {
+    fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
+        let query_lines = old_text.lines().collect();
+        Self {
+            snapshot,
+            query_lines,
+            matrix: SearchMatrix::new(0),
+        }
+    }
+
+    fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
+        let point_range = range.to_point(&self.snapshot);
+        let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
+
+        self.matrix
+            .reset(self.query_lines.len() + 1, buffer_line_count + 1);
+        let query_line_count = self.query_lines.len();
+
+        for row in 0..query_line_count {
+            let query_line = self.query_lines[row].trim();
+            let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
+
+            self.matrix.set(
+                row + 1,
+                0,
+                SearchState::new(leading_deletion_cost, SearchDirection::Up),
+            );
+
+            let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
+
+            let mut col = 0;
+            while let Some(buffer_line) = buffer_lines.next() {
+                let buffer_line = buffer_line.trim();
+                let up = SearchState::new(
+                    self.matrix
+                        .get(row, col + 1)
+                        .cost
+                        .saturating_add(DELETION_COST),
+                    SearchDirection::Up,
+                );
+                let left = SearchState::new(
+                    self.matrix
+                        .get(row + 1, col)
+                        .cost
+                        .saturating_add(INSERTION_COST),
+                    SearchDirection::Left,
+                );
+                let diagonal = SearchState::new(
+                    if query_line == buffer_line {
+                        self.matrix.get(row, col).cost
+                    } else if fuzzy_eq(query_line, buffer_line) {
+                        self.matrix.get(row, col).cost + REPLACEMENT_COST
+                    } else {
+                        self.matrix
+                            .get(row, col)
+                            .cost
+                            .saturating_add(DELETION_COST + INSERTION_COST)
+                    },
+                    SearchDirection::Diagonal,
+                );
+                self.matrix
+                    .set(row + 1, col + 1, up.min(left).min(diagonal));
+                col += 1;
+            }
+        }
+
+        // Find all matches with the best cost
+        let mut best_cost = u32::MAX;
+        let mut matches_with_best_cost = Vec::new();
+
+        for col in 1..=buffer_line_count {
+            let cost = self.matrix.get(query_line_count, col).cost;
+            if cost < best_cost {
+                best_cost = cost;
+                matches_with_best_cost.clear();
+                matches_with_best_cost.push(col as u32);
+            } else if cost == best_cost {
+                matches_with_best_cost.push(col as u32);
+            }
+        }
+
+        // Find ranges for the matches
+        for &match_end_col in &matches_with_best_cost {
+            let mut matched_lines = 0;
+            let mut query_row = query_line_count;
+            let mut match_start_col = match_end_col;
+            while query_row > 0 && match_start_col > 0 {
+                let current = self.matrix.get(query_row, match_start_col as usize);
+                match current.direction {
+                    SearchDirection::Diagonal => {
+                        query_row -= 1;
+                        match_start_col -= 1;
+                        matched_lines += 1;
+                    }
+                    SearchDirection::Up => {
+                        query_row -= 1;
+                    }
+                    SearchDirection::Left => {
+                        match_start_col -= 1;
+                    }
+                }
+            }
+
+            let buffer_row_start = match_start_col + point_range.start.row;
+            let buffer_row_end = match_end_col + point_range.start.row;
+
+            let matched_buffer_row_count = buffer_row_end - buffer_row_start;
+            let matched_ratio = matched_lines as f32
+                / (matched_buffer_row_count as f32).max(query_line_count as f32);
+            if matched_ratio >= 0.8 {
+                let buffer_start_ix = self
+                    .snapshot
+                    .point_to_offset(Point::new(buffer_row_start, 0));
+                let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
+                    buffer_row_end - 1,
+                    self.snapshot.line_len(buffer_row_end - 1),
+                ));
+                return Some((best_cost, buffer_start_ix..buffer_end_ix));
+            }
+        }
+
+        None
+    }
+}
+
+fn fuzzy_eq(left: &str, right: &str) -> bool {
+    const THRESHOLD: f64 = 0.8;
+
+    let min_levenshtein = left.len().abs_diff(right.len());
+    let min_normalized_levenshtein =
+        1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
+    if min_normalized_levenshtein < THRESHOLD {
+        return false;
+    }
+
+    strsim::normalized_levenshtein(left, right) >= THRESHOLD
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+enum SearchDirection {
+    Up,
+    Left,
+    Diagonal,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
+struct SearchState {
+    cost: u32,
+    direction: SearchDirection,
+}
+
+impl SearchState {
+    fn new(cost: u32, direction: SearchDirection) -> Self {
+        Self { cost, direction }
+    }
+}
+
+struct SearchMatrix {
+    cols: usize,
+    rows: usize,
+    data: Vec<SearchState>,
+}
+
+impl SearchMatrix {
+    fn new(cols: usize) -> Self {
+        SearchMatrix {
+            cols,
+            rows: 0,
+            data: Vec::new(),
+        }
+    }
+
+    fn reset(&mut self, rows: usize, cols: usize) {
+        self.rows = rows;
+        self.cols = cols;
+        self.data
+            .fill(SearchState::new(0, SearchDirection::Diagonal));
+        self.data.resize(
+            self.rows * self.cols,
+            SearchState::new(0, SearchDirection::Diagonal),
+        );
+    }
+
+    fn get(&self, row: usize, col: usize) -> SearchState {
+        debug_assert!(row < self.rows);
+        debug_assert!(col < self.cols);
+        self.data[row * self.cols + col]
+    }
+
+    fn set(&mut self, row: usize, col: usize, state: SearchState) {
+        debug_assert!(row < self.rows && col < self.cols);
+        self.data[row * self.cols + col] = state;
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -212,7 +366,7 @@ mod tests {
             "# };
         let parsed = parse_tag(&mut input, "tag").unwrap().unwrap();
         assert_eq!(parsed.attributes, "attr=\"foo\"");
-        assert_eq!(parsed.body, "tag value\n");
+        assert_eq!(parsed.body, "tag value");
         assert_eq!(input, "\n");
     }
 
@@ -224,7 +378,9 @@ mod tests {
             one two three four
             five six seven eight
             nine ten eleven twelve
-        "# };
+            thirteen fourteen fifteen
+            sixteen seventeen eighteen
+        "#};
 
         fs.insert_tree(
             path!("/root"),
@@ -246,16 +402,17 @@ mod tests {
         let edits = indoc! {r#"
             <edits path="root/file1">
             <old_text>
-            five six seven eight
+            nine ten eleven twelve
             </old_text>
             <new_text>
-            five SIX seven eight!
+            nine TEN eleven twelve!
             </new_text>
             </edits>
         "#};
 
+        let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
         let (buffer, edits) = parse_xml_edits(edits, |_path| {
-            Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
+            Some((&buffer_snapshot, included_ranges.as_slice()))
         })
         .await
         .unwrap();
@@ -267,8 +424,8 @@ mod tests {
         assert_eq!(
             edits,
             &[
-                (Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
-                (Point::new(1, 20)..Point::new(1, 20), "!".into())
+                (Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
+                (Point::new(2, 22)..Point::new(2, 22), "!".into())
             ]
         );
     }

crates/zeta2/src/zeta2.rs 🔗

@@ -42,14 +42,14 @@ use util::rel_path::RelPathBuf;
 use util::{LogErrorFuture, TryFutureExt};
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 
-pub mod merge_excerpts;
+pub mod assemble_excerpts;
 mod prediction;
 mod provider;
 pub mod retrieval_search;
 pub mod udiff;
 mod xml_edits;
 
-use crate::merge_excerpts::merge_excerpts;
+use crate::assemble_excerpts::assemble_excerpts;
 use crate::prediction::EditPrediction;
 pub use crate::prediction::EditPredictionId;
 pub use provider::ZetaEditPredictionProvider;
@@ -820,16 +820,8 @@ impl Zeta {
                             })
                         {
                             let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
-                            let range_ix = ranges
-                                .binary_search_by(|probe| {
-                                    probe
-                                        .start
-                                        .cmp(&excerpt_anchor_range.start, buffer)
-                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
-                                })
-                                .unwrap_or_else(|ix| ix);
-
-                            ranges.insert(range_ix, excerpt_anchor_range);
+                            ranges.push(excerpt_anchor_range);
+                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                             let last_ix = included_files.len() - 1;
                             included_files.swap(buffer_ix, last_ix);
                         } else {
@@ -844,13 +836,14 @@ impl Zeta {
                         let included_files = included_files
                             .iter()
                             .map(|(_, snapshot, path, ranges)| {
-                                let excerpts = merge_excerpts(
-                                    &snapshot,
-                                    ranges.iter().map(|range| {
+                                let ranges = ranges
+                                    .iter()
+                                    .map(|range| {
                                         let point_range = range.to_point(&snapshot);
                                         Line(point_range.start.row)..Line(point_range.end.row)
-                                    }),
-                                );
+                                    })
+                                    .collect::<Vec<_>>();
+                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                 predict_edits_v3::IncludedFile {
                                     path: path.clone(),
                                     max_row: Line(snapshot.max_point().row),

crates/zeta_cli/src/evaluate.rs 🔗

@@ -84,7 +84,7 @@ pub async fn run_evaluate(
     {
         write_aggregated_scores(&mut output_file, &all_results).log_err();
     };
-    print_run_data_dir(args.repetitions == 1);
+    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
 }
 
 fn write_aggregated_scores(
@@ -103,8 +103,7 @@ fn write_aggregated_scores(
                 }
 
                 failed_count += 1;
-                let err = err
-                    .to_string()
+                let err = format!("{err:?}")
                     .replace("<edits", "```xml\n<edits")
                     .replace("</edits>", "</edits>\n```");
                 writeln!(
@@ -173,6 +172,7 @@ pub async fn run_evaluate_one(
             &predict_result,
             &evaluation_result,
             &mut std::io::stdout(),
+            std::io::stdout().is_terminal(),
         )?;
     }
 
@@ -184,6 +184,7 @@ pub async fn run_evaluate_one(
             &predict_result,
             &evaluation_result,
             &mut results_file,
+            false,
         )
         .log_err();
     }
@@ -196,16 +197,25 @@ fn write_eval_result(
     predictions: &PredictionDetails,
     evaluation_result: &EvaluationResult,
     out: &mut impl Write,
+    use_color: bool,
 ) -> Result<()> {
     writeln!(
         out,
         "## Expected edit prediction:\n\n```diff\n{}\n```\n",
-        compare_diffs(&example.example.expected_patch, &predictions.diff)
+        compare_diffs(
+            &example.example.expected_patch,
+            &predictions.diff,
+            use_color
+        )
     )?;
     writeln!(
         out,
         "## Actual edit prediction:\n\n```diff\n{}\n```\n",
-        compare_diffs(&predictions.diff, &example.example.expected_patch)
+        compare_diffs(
+            &predictions.diff,
+            &example.example.expected_patch,
+            use_color
+        )
     )?;
     writeln!(out, "{:#}", evaluation_result)?;
 
@@ -434,8 +444,7 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul
 /// Return annotated `patch_a` so that:
 /// Additions and deletions that are not present in `patch_b` will be highlighted in red.
 /// Additions and deletions that are present in `patch_b` will be highlighted in green.
-pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String {
-    let use_color = std::io::stdout().is_terminal();
+pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
     let green = if use_color { "\x1b[32m✓ " } else { "" };
     let red = if use_color { "\x1b[31m✗ " } else { "" };
     let neutral = if use_color { "  " } else { "" };

crates/zeta_cli/src/paths.rs 🔗

@@ -13,7 +13,7 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
 pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
     LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
 
-pub fn print_run_data_dir(deep: bool) {
+pub fn print_run_data_dir(deep: bool, use_color: bool) {
     println!("\n## Run Data\n");
     let mut files = Vec::new();
 
@@ -25,18 +25,22 @@ pub fn print_run_data_dir(deep: bool) {
                 let path = file.unwrap().path();
                 let path = path.strip_prefix(&current_dir).unwrap_or(&path);
                 files.push(format!(
-                    "- {}/\x1b[34m{}\x1b[0m",
+                    "- {}/{}{}{}",
                     path.parent().unwrap().display(),
+                    if use_color { "\x1b[34m" } else { "" },
                     path.file_name().unwrap().display(),
+                    if use_color { "\x1b[0m" } else { "" },
                 ));
             }
         } else {
             let path = file.path();
             let path = path.strip_prefix(&current_dir).unwrap_or(&path);
             files.push(format!(
-                "- {}/\x1b[34m{}\x1b[0m",
+                "- {}/{}{}{}",
                 path.parent().unwrap().display(),
+                if use_color { "\x1b[34m" } else { "" },
                 path.file_name().unwrap().display(),
+                if use_color { "\x1b[0m" } else { "" }
             ));
         }
     }

crates/zeta_cli/src/predict.rs 🔗

@@ -13,7 +13,7 @@ use language::{Anchor, Buffer, Point};
 use project::Project;
 use serde::Deserialize;
 use std::fs;
-use std::io::Write;
+use std::io::{IsTerminal, Write};
 use std::ops::Range;
 use std::path::PathBuf;
 use std::sync::Arc;
@@ -98,7 +98,7 @@ pub async fn run_zeta2_predict(
     .unwrap();
     result.write(args.format, std::io::stdout()).unwrap();
 
-    print_run_data_dir(true);
+    print_run_data_dir(true, std::io::stdout().is_terminal());
 }
 
 pub async fn zeta2_predict(
@@ -289,8 +289,11 @@ pub async fn zeta2_predict(
             let new_text = prediction
                 .buffer
                 .update(cx, |buffer, cx| {
-                    buffer.edit(prediction.edits.iter().cloned(), None, cx);
-                    buffer.text()
+                    let branch = buffer.branch(cx);
+                    branch.update(cx, |branch, cx| {
+                        branch.edit(prediction.edits.iter().cloned(), None, cx);
+                        branch.text()
+                    })
                 })
                 .unwrap();
             language::unified_diff(&old_text, &new_text)