diff --git a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs index a11c56da41384257b8331a31161224c9e25d0894..e334674ef8004b485608e3864cf1e4e8d4c97cdb 100644 --- a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs @@ -11,11 +11,11 @@ pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> R let mut prompt = SEARCH_INSTRUCTIONS.to_string(); if !request.events.is_empty() { - writeln!(&mut prompt, "## User Edits\n")?; + writeln!(&mut prompt, "\n## User Edits\n\n")?; push_events(&mut prompt, &request.events); } - writeln!(&mut prompt, "## Cursor context")?; + writeln!(&mut prompt, "## Cursor context\n")?; write_codeblock( &request.excerpt_path, &[Excerpt { diff --git a/crates/zeta2/src/merge_excerpts.rs b/crates/zeta2/src/assemble_excerpts.rs similarity index 91% rename from crates/zeta2/src/merge_excerpts.rs rename to crates/zeta2/src/assemble_excerpts.rs index 846d8034a8c2e88b8552dc8c9d48af6ccdc5efcf..f2a5b5adb1fcffab945cd9bdb88153bc5e494138 100644 --- a/crates/zeta2/src/merge_excerpts.rs +++ b/crates/zeta2/src/assemble_excerpts.rs @@ -3,27 +3,16 @@ use edit_prediction_context::Line; use language::{BufferSnapshot, Point}; use std::ops::Range; -pub fn merge_excerpts( +pub fn assemble_excerpts( buffer: &BufferSnapshot, - sorted_line_ranges: impl IntoIterator>, + merged_line_ranges: impl IntoIterator>, ) -> Vec { let mut output = Vec::new(); - let mut merged_ranges = Vec::>::new(); - - for line_range in sorted_line_ranges { - if let Some(last_line_range) = merged_ranges.last_mut() - && line_range.start <= last_line_range.end - { - last_line_range.end = last_line_range.end.max(line_range.end); - continue; - } - merged_ranges.push(line_range); - } let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None); let mut outline_items = outline_items.into_iter().peekable(); - for range in merged_ranges { + for range in merged_line_ranges { let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0); while let Some(outline_item) = outline_items.peek() { @@ -155,7 +144,7 @@ mod tests { let mut output = String::new(); cloud_zeta2_prompt::write_excerpts( - merge_excerpts(&buffer.snapshot(), ranges).iter(), + assemble_excerpts(&buffer.snapshot(), ranges).iter(), &insertions, Line(buffer.max_point().row), true, diff --git a/crates/zeta2/src/retrieval_search.rs b/crates/zeta2/src/retrieval_search.rs index d642c2edaa1fbc897b3c74b0b5c8b1fb71227e84..fd7364cf23ac66fe9baf2f911868ef251d2d25cf 100644 --- a/crates/zeta2/src/retrieval_search.rs +++ b/crates/zeta2/src/retrieval_search.rs @@ -64,7 +64,7 @@ pub async fn run_retrieval_searches( })? .await?; let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; - let mut ranges = ranges + let mut ranges: Vec<_> = ranges .into_iter() .map(|range| { snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end) @@ -172,11 +172,11 @@ pub async fn run_retrieval_searches( .await } -fn merge_anchor_ranges(ranges: &mut Vec>, snapshot: &BufferSnapshot) { +pub(crate) fn merge_anchor_ranges(ranges: &mut Vec>, snapshot: &BufferSnapshot) { ranges.sort_unstable_by(|a, b| { a.start .cmp(&b.start, snapshot) - .then(b.end.cmp(&b.end, snapshot)) + .then(b.end.cmp(&a.end, snapshot)) }); let mut index = 1; @@ -187,7 +187,9 @@ fn merge_anchor_ranges(ranges: &mut Vec>, snapshot: &BufferSnapsho .is_ge() { let removed = ranges.remove(index); - ranges[index - 1].end = removed.end; + if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() { + ranges[index - 1].end = removed.end; + } } else { index += 1; } @@ -416,7 +418,7 @@ fn expand_to_parent_range( #[cfg(test)] mod tests { use super::*; - use crate::merge_excerpts::merge_excerpts; + use crate::assemble_excerpts::assemble_excerpts; use cloud_zeta2_prompt::write_codeblock; use edit_prediction_context::Line; use gpui::TestAppContext; @@ -602,7 +604,7 @@ mod tests { write_codeblock( &buffer.file().unwrap().full_path(cx), - merge_excerpts(&buffer.snapshot(), excerpts).iter(), + assemble_excerpts(&buffer.snapshot(), excerpts).iter(), &[], Line(buffer.max_point().row), false, diff --git a/crates/zeta2/src/xml_edits.rs b/crates/zeta2/src/xml_edits.rs index d1eea285d6861dc4cbe6fe65a133453d5b06adaf..468efa8b202141c4cca04459233ea91c5bff9d44 100644 --- a/crates/zeta2/src/xml_edits.rs +++ b/crates/zeta2/src/xml_edits.rs @@ -1,8 +1,6 @@ -use anyhow::{Context as _, Result, anyhow}; -use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot}; -use std::ops::Range; -use std::path::Path; -use std::sync::Arc; +use anyhow::{Context as _, Result}; +use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point}; +use std::{cmp, ops::Range, path::Path, sync::Arc}; pub async fn parse_xml_edits<'a>( input: &'a str, @@ -40,128 +38,76 @@ async fn parse_xml_edits_inner<'a>( while let Some(old_text_tag) = parse_tag(&mut input, "old_text")? { let new_text_tag = parse_tag(&mut input, "new_text")?.context("no new_text tag following old_text")?; - edits.extend(resolve_new_text_old_text_in_buffer( - new_text_tag.body, - old_text_tag.body, - buffer, - context_ranges, - )?); + let match_range = fuzzy_match_in_ranges(old_text_tag.body, buffer, context_ranges)?; + let old_text = buffer + .text_for_range(match_range.clone()) + .collect::(); + let edits_within_hunk = language::text_diff(&old_text, &new_text_tag.body); + edits.extend( + edits_within_hunk + .into_iter() + .map(move |(inner_range, inner_text)| { + ( + buffer.anchor_after(match_range.start + inner_range.start) + ..buffer.anchor_before(match_range.start + inner_range.end), + inner_text, + ) + }), + ); } Ok((buffer, edits)) } -fn resolve_new_text_old_text_in_buffer( - new_text: &str, +fn fuzzy_match_in_ranges( old_text: &str, - buffer: &TextBufferSnapshot, - ranges: &[Range], -) -> Result, Arc)>, anyhow::Error> { - let context_offset = if old_text.is_empty() { - Ok(0) - } else { - let mut offset = None; - for range in ranges { - let range = range.to_offset(buffer); - let text = buffer.text_for_range(range.clone()).collect::(); - for (match_offset, _) in text.match_indices(old_text) { - if let Some(offset) = offset { - let offset_match_point = buffer.offset_to_point(offset); - let second_match_point = buffer.offset_to_point(range.start + match_offset); - anyhow::bail!( - "old_text is not unique enough:\n{}\nFound at {:?} and {:?}", - old_text, - offset_match_point, - second_match_point - ); + buffer: &BufferSnapshot, + context_ranges: &[Range], +) -> Result> { + let mut state = FuzzyMatcher::new(buffer, old_text); + let mut best_match = None; + let mut tie_match_range = None; + + for range in context_ranges { + let best_match_cost = best_match.as_ref().map(|(score, _)| *score); + match (best_match_cost, state.match_range(range.to_offset(buffer))) { + (Some(lowest_cost), Some((new_cost, new_range))) => { + if new_cost == lowest_cost { + tie_match_range = Some(new_range); + } else if new_cost < lowest_cost { + tie_match_range.take(); + best_match = Some((new_cost, new_range)); } - offset = Some(range.start + match_offset); } - } - offset.ok_or_else(|| { - #[cfg(any(debug_assertions, feature = "eval-support"))] - if let Some(closest_match) = closest_old_text_match(buffer, old_text) { - log::info!( - "Closest `old_text` match: {}", - pretty_assertions::StrComparison::new(old_text, &closest_match) - ) + (None, Some(new_match)) => { + best_match = Some(new_match); } - anyhow!("Failed to match old_text:\n{}", old_text) - }) - }?; - - let edits_within_hunk = language::text_diff(&old_text, &new_text); - Ok(edits_within_hunk - .into_iter() - .map(move |(inner_range, inner_text)| { - ( - buffer.anchor_after(context_offset + inner_range.start) - ..buffer.anchor_before(context_offset + inner_range.end), - inner_text, - ) - })) -} - -#[cfg(any(debug_assertions, feature = "eval-support"))] -fn closest_old_text_match(buffer: &TextBufferSnapshot, old_text: &str) -> Option { - let buffer_text = buffer.text(); - let len = old_text.len(); - - if len == 0 || buffer_text.len() < len { - return None; + (None, None) | (Some(_), None) => {} + }; } - let mut min_score = usize::MAX; - let mut min_start = 0; - - let old_text_bytes = old_text.as_bytes(); - let old_alpha_count = old_text_bytes - .iter() - .filter(|&&b| b.is_ascii_alphanumeric()) - .count(); - - let old_line_count = old_text.lines().count(); - - let mut cursor = 0; - - while cursor + len <= buffer_text.len() { - let candidate = &buffer_text[cursor..cursor + len]; - let candidate_bytes = candidate.as_bytes(); - - if usize::abs_diff(candidate.lines().count(), old_line_count) > 4 { - cursor += 1; - continue; - } - - let candidate_alpha_count = candidate_bytes - .iter() - .filter(|&&b| b.is_ascii_alphanumeric()) - .count(); - - // If alphanumeric character count differs by more than 30%, skip - if usize::abs_diff(old_alpha_count, candidate_alpha_count) * 10 > old_alpha_count * 3 { - cursor += 1; - continue; - } - - let score = strsim::levenshtein(candidate, old_text); - if score < min_score { - min_score = score; - min_start = cursor; - - if min_score <= len / 10 { - break; - } + if let Some((_, best_match_range)) = best_match { + if let Some(tie_match_range) = tie_match_range { + anyhow::bail!( + "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}", + best_match_range.clone(), + buffer.text_for_range(best_match_range).collect::(), + tie_match_range.clone(), + buffer.text_for_range(tie_match_range).collect::() + ); } - - cursor += 1; + return Ok(best_match_range); } - if min_score != usize::MAX { - Some(buffer_text[min_start..min_start + len].to_string()) - } else { - None - } + anyhow::bail!( + "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```", + old_text, + context_ranges + .iter() + .map(|range| buffer.text_for_range(range.clone()).collect::()) + .collect::>() + .join("```\n```") + ); } struct ParsedTag<'a> { @@ -187,10 +133,218 @@ fn parse_tag<'a>(input: &mut &'a str, tag: &str) -> Result> .with_context(|| format!("no `{close_tag}` tag"))?; let body = &input[closing_bracket_ix + '>'.len_utf8()..end_ix]; let body = body.strip_prefix('\n').unwrap_or(body); + let body = body.strip_suffix('\n').unwrap_or(body); *input = &input[end_ix + close_tag.len()..]; Ok(Some(ParsedTag { attributes, body })) } +const REPLACEMENT_COST: u32 = 1; +const INSERTION_COST: u32 = 3; +const DELETION_COST: u32 = 10; + +/// A fuzzy matcher that can process text chunks incrementally +/// and return the best match found so far at each step. +struct FuzzyMatcher<'a> { + snapshot: &'a BufferSnapshot, + query_lines: Vec<&'a str>, + matrix: SearchMatrix, +} + +impl<'a> FuzzyMatcher<'a> { + fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self { + let query_lines = old_text.lines().collect(); + Self { + snapshot, + query_lines, + matrix: SearchMatrix::new(0), + } + } + + fn match_range(&mut self, range: Range) -> Option<(u32, Range)> { + let point_range = range.to_point(&self.snapshot); + let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize; + + self.matrix + .reset(self.query_lines.len() + 1, buffer_line_count + 1); + let query_line_count = self.query_lines.len(); + + for row in 0..query_line_count { + let query_line = self.query_lines[row].trim(); + let leading_deletion_cost = (row + 1) as u32 * DELETION_COST; + + self.matrix.set( + row + 1, + 0, + SearchState::new(leading_deletion_cost, SearchDirection::Up), + ); + + let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines(); + + let mut col = 0; + while let Some(buffer_line) = buffer_lines.next() { + let buffer_line = buffer_line.trim(); + let up = SearchState::new( + self.matrix + .get(row, col + 1) + .cost + .saturating_add(DELETION_COST), + SearchDirection::Up, + ); + let left = SearchState::new( + self.matrix + .get(row + 1, col) + .cost + .saturating_add(INSERTION_COST), + SearchDirection::Left, + ); + let diagonal = SearchState::new( + if query_line == buffer_line { + self.matrix.get(row, col).cost + } else if fuzzy_eq(query_line, buffer_line) { + self.matrix.get(row, col).cost + REPLACEMENT_COST + } else { + self.matrix + .get(row, col) + .cost + .saturating_add(DELETION_COST + INSERTION_COST) + }, + SearchDirection::Diagonal, + ); + self.matrix + .set(row + 1, col + 1, up.min(left).min(diagonal)); + col += 1; + } + } + + // Find all matches with the best cost + let mut best_cost = u32::MAX; + let mut matches_with_best_cost = Vec::new(); + + for col in 1..=buffer_line_count { + let cost = self.matrix.get(query_line_count, col).cost; + if cost < best_cost { + best_cost = cost; + matches_with_best_cost.clear(); + matches_with_best_cost.push(col as u32); + } else if cost == best_cost { + matches_with_best_cost.push(col as u32); + } + } + + // Find ranges for the matches + for &match_end_col in &matches_with_best_cost { + let mut matched_lines = 0; + let mut query_row = query_line_count; + let mut match_start_col = match_end_col; + while query_row > 0 && match_start_col > 0 { + let current = self.matrix.get(query_row, match_start_col as usize); + match current.direction { + SearchDirection::Diagonal => { + query_row -= 1; + match_start_col -= 1; + matched_lines += 1; + } + SearchDirection::Up => { + query_row -= 1; + } + SearchDirection::Left => { + match_start_col -= 1; + } + } + } + + let buffer_row_start = match_start_col + point_range.start.row; + let buffer_row_end = match_end_col + point_range.start.row; + + let matched_buffer_row_count = buffer_row_end - buffer_row_start; + let matched_ratio = matched_lines as f32 + / (matched_buffer_row_count as f32).max(query_line_count as f32); + if matched_ratio >= 0.8 { + let buffer_start_ix = self + .snapshot + .point_to_offset(Point::new(buffer_row_start, 0)); + let buffer_end_ix = self.snapshot.point_to_offset(Point::new( + buffer_row_end - 1, + self.snapshot.line_len(buffer_row_end - 1), + )); + return Some((best_cost, buffer_start_ix..buffer_end_ix)); + } + } + + None + } +} + +fn fuzzy_eq(left: &str, right: &str) -> bool { + const THRESHOLD: f64 = 0.8; + + let min_levenshtein = left.len().abs_diff(right.len()); + let min_normalized_levenshtein = + 1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64); + if min_normalized_levenshtein < THRESHOLD { + return false; + } + + strsim::normalized_levenshtein(left, right) >= THRESHOLD +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +enum SearchDirection { + Up, + Left, + Diagonal, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct SearchState { + cost: u32, + direction: SearchDirection, +} + +impl SearchState { + fn new(cost: u32, direction: SearchDirection) -> Self { + Self { cost, direction } + } +} + +struct SearchMatrix { + cols: usize, + rows: usize, + data: Vec, +} + +impl SearchMatrix { + fn new(cols: usize) -> Self { + SearchMatrix { + cols, + rows: 0, + data: Vec::new(), + } + } + + fn reset(&mut self, rows: usize, cols: usize) { + self.rows = rows; + self.cols = cols; + self.data + .fill(SearchState::new(0, SearchDirection::Diagonal)); + self.data.resize( + self.rows * self.cols, + SearchState::new(0, SearchDirection::Diagonal), + ); + } + + fn get(&self, row: usize, col: usize) -> SearchState { + debug_assert!(row < self.rows); + debug_assert!(col < self.cols); + self.data[row * self.cols + col] + } + + fn set(&mut self, row: usize, col: usize, state: SearchState) { + debug_assert!(row < self.rows && col < self.cols); + self.data[row * self.cols + col] = state; + } +} + #[cfg(test)] mod tests { use super::*; @@ -212,7 +366,7 @@ mod tests { "# }; let parsed = parse_tag(&mut input, "tag").unwrap().unwrap(); assert_eq!(parsed.attributes, "attr=\"foo\""); - assert_eq!(parsed.body, "tag value\n"); + assert_eq!(parsed.body, "tag value"); assert_eq!(input, "\n"); } @@ -224,7 +378,9 @@ mod tests { one two three four five six seven eight nine ten eleven twelve - "# }; + thirteen fourteen fifteen + sixteen seventeen eighteen + "#}; fs.insert_tree( path!("/root"), @@ -246,16 +402,17 @@ mod tests { let edits = indoc! {r#" - five six seven eight + nine ten eleven twelve - five SIX seven eight! + nine TEN eleven twelve! "#}; + let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)]; let (buffer, edits) = parse_xml_edits(edits, |_path| { - Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_])) + Some((&buffer_snapshot, included_ranges.as_slice())) }) .await .unwrap(); @@ -267,8 +424,8 @@ mod tests { assert_eq!( edits, &[ - (Point::new(1, 5)..Point::new(1, 8), "SIX".into()), - (Point::new(1, 20)..Point::new(1, 20), "!".into()) + (Point::new(2, 5)..Point::new(2, 8), "TEN".into()), + (Point::new(2, 22)..Point::new(2, 22), "!".into()) ] ); } diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 7322cb4b6e6882ad2f3597abb505224cc24dbd5e..1521fbd9291c7a69cc56152d193734f41cf0451e 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -42,14 +42,14 @@ use util::rel_path::RelPathBuf; use util::{LogErrorFuture, TryFutureExt}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; -pub mod merge_excerpts; +pub mod assemble_excerpts; mod prediction; mod provider; pub mod retrieval_search; pub mod udiff; mod xml_edits; -use crate::merge_excerpts::merge_excerpts; +use crate::assemble_excerpts::assemble_excerpts; use crate::prediction::EditPrediction; pub use crate::prediction::EditPredictionId; pub use provider::ZetaEditPredictionProvider; @@ -820,16 +820,8 @@ impl Zeta { }) { let (_, buffer, _, ranges) = &mut included_files[buffer_ix]; - let range_ix = ranges - .binary_search_by(|probe| { - probe - .start - .cmp(&excerpt_anchor_range.start, buffer) - .then(excerpt_anchor_range.end.cmp(&probe.end, buffer)) - }) - .unwrap_or_else(|ix| ix); - - ranges.insert(range_ix, excerpt_anchor_range); + ranges.push(excerpt_anchor_range); + retrieval_search::merge_anchor_ranges(ranges, buffer); let last_ix = included_files.len() - 1; included_files.swap(buffer_ix, last_ix); } else { @@ -844,13 +836,14 @@ impl Zeta { let included_files = included_files .iter() .map(|(_, snapshot, path, ranges)| { - let excerpts = merge_excerpts( - &snapshot, - ranges.iter().map(|range| { + let ranges = ranges + .iter() + .map(|range| { let point_range = range.to_point(&snapshot); Line(point_range.start.row)..Line(point_range.end.row) - }), - ); + }) + .collect::>(); + let excerpts = assemble_excerpts(&snapshot, ranges); predict_edits_v3::IncludedFile { path: path.clone(), max_row: Line(snapshot.max_point().row), diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/zeta_cli/src/evaluate.rs index d255d1a56102d836cc18ce4df10586edad0ca957..b0b3820362889051e3e5c0eef03ef10c7f0d6fa8 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/zeta_cli/src/evaluate.rs @@ -84,7 +84,7 @@ pub async fn run_evaluate( { write_aggregated_scores(&mut output_file, &all_results).log_err(); }; - print_run_data_dir(args.repetitions == 1); + print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal()); } fn write_aggregated_scores( @@ -103,8 +103,7 @@ fn write_aggregated_scores( } failed_count += 1; - let err = err - .to_string() + let err = format!("{err:?}") .replace("", "\n```"); writeln!( @@ -173,6 +172,7 @@ pub async fn run_evaluate_one( &predict_result, &evaluation_result, &mut std::io::stdout(), + std::io::stdout().is_terminal(), )?; } @@ -184,6 +184,7 @@ pub async fn run_evaluate_one( &predict_result, &evaluation_result, &mut results_file, + false, ) .log_err(); } @@ -196,16 +197,25 @@ fn write_eval_result( predictions: &PredictionDetails, evaluation_result: &EvaluationResult, out: &mut impl Write, + use_color: bool, ) -> Result<()> { writeln!( out, "## Expected edit prediction:\n\n```diff\n{}\n```\n", - compare_diffs(&example.example.expected_patch, &predictions.diff) + compare_diffs( + &example.example.expected_patch, + &predictions.diff, + use_color + ) )?; writeln!( out, "## Actual edit prediction:\n\n```diff\n{}\n```\n", - compare_diffs(&predictions.diff, &example.example.expected_patch) + compare_diffs( + &predictions.diff, + &example.example.expected_patch, + use_color + ) )?; writeln!(out, "{:#}", evaluation_result)?; @@ -434,8 +444,7 @@ pub fn evaluate(example: &Example, preds: &PredictionDetails) -> EvaluationResul /// Return annotated `patch_a` so that: /// Additions and deletions that are not present in `patch_b` will be highlighted in red. /// Additions and deletions that are present in `patch_b` will be highlighted in green. -pub fn compare_diffs(patch_a: &str, patch_b: &str) -> String { - let use_color = std::io::stdout().is_terminal(); +pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String { let green = if use_color { "\x1b[32m✓ " } else { "" }; let red = if use_color { "\x1b[31m✗ " } else { "" }; let neutral = if use_color { " " } else { "" }; diff --git a/crates/zeta_cli/src/paths.rs b/crates/zeta_cli/src/paths.rs index 15c4941f3dacce0b9a06c15daee431014b12944d..3cc2beec5bd50380b9eef8b502dcba0ccba32772 100644 --- a/crates/zeta_cli/src/paths.rs +++ b/crates/zeta_cli/src/paths.rs @@ -13,7 +13,7 @@ pub static RUN_DIR: LazyLock = LazyLock::new(|| { pub static LATEST_EXAMPLE_RUN_DIR: LazyLock = LazyLock::new(|| TARGET_ZETA_DIR.join("latest")); -pub fn print_run_data_dir(deep: bool) { +pub fn print_run_data_dir(deep: bool, use_color: bool) { println!("\n## Run Data\n"); let mut files = Vec::new(); @@ -25,18 +25,22 @@ pub fn print_run_data_dir(deep: bool) { let path = file.unwrap().path(); let path = path.strip_prefix(¤t_dir).unwrap_or(&path); files.push(format!( - "- {}/\x1b[34m{}\x1b[0m", + "- {}/{}{}{}", path.parent().unwrap().display(), + if use_color { "\x1b[34m" } else { "" }, path.file_name().unwrap().display(), + if use_color { "\x1b[0m" } else { "" }, )); } } else { let path = file.path(); let path = path.strip_prefix(¤t_dir).unwrap_or(&path); files.push(format!( - "- {}/\x1b[34m{}\x1b[0m", + "- {}/{}{}{}", path.parent().unwrap().display(), + if use_color { "\x1b[34m" } else { "" }, path.file_name().unwrap().display(), + if use_color { "\x1b[0m" } else { "" } )); } } diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs index 1f419fd09a87d1270d73bc90fe4b312cbaf0b4a4..0618cf38bafd15a6b8a50b03cb745c9d3365cbf8 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/zeta_cli/src/predict.rs @@ -13,7 +13,7 @@ use language::{Anchor, Buffer, Point}; use project::Project; use serde::Deserialize; use std::fs; -use std::io::Write; +use std::io::{IsTerminal, Write}; use std::ops::Range; use std::path::PathBuf; use std::sync::Arc; @@ -98,7 +98,7 @@ pub async fn run_zeta2_predict( .unwrap(); result.write(args.format, std::io::stdout()).unwrap(); - print_run_data_dir(true); + print_run_data_dir(true, std::io::stdout().is_terminal()); } pub async fn zeta2_predict( @@ -289,8 +289,11 @@ pub async fn zeta2_predict( let new_text = prediction .buffer .update(cx, |buffer, cx| { - buffer.edit(prediction.edits.iter().cloned(), None, cx); - buffer.text() + let branch = buffer.branch(cx); + branch.update(cx, |branch, cx| { + branch.edit(prediction.edits.iter().cloned(), None, cx); + branch.text() + }) }) .unwrap(); language::unified_diff(&old_text, &new_text)