diff --git a/Cargo.lock b/Cargo.lock index f7597693960b2c9e66121794f9c99cdb8d6ddcea..e1a5a11ad0c0549791545cd7e020e283decb5b53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22542,6 +22542,7 @@ name = "zeta_prompt" version = "0.1.0" dependencies = [ "anyhow", + "imara-diff", "indoc", "serde", "strum 0.27.2", diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 1ba8b27aa785024a47a09c3299a1f3786a028ccf..ea7233cd976148f5eb726730635e0efaf6ceef86 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -2707,6 +2707,65 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte }); } +#[gpui::test] +async fn test_v3_prediction_strips_cursor_marker_from_edit_text(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + + fs.insert_tree( + "/root", + json!({ + "foo.txt": "hello" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project + .find_project_path(path!("root/foo.txt"), cx) + .unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(0, 5)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (request, respond_tx) = requests.predict.next().await.unwrap(); + let excerpt_length = request.input.cursor_excerpt.len(); + respond_tx + .send(PredictEditsV3Response { + request_id: Uuid::new_v4().to_string(), + output: "hello<|user_cursor|> world".to_string(), + editable_range: 0..excerpt_length, + model_version: None, + }) + .unwrap(); + + cx.run_until_parked(); + + ep_store.update(cx, |ep_store, cx| { + let prediction = ep_store + .prediction_at(&buffer, None, &project, cx) + .expect("should have prediction"); + let snapshot = buffer.read(cx).snapshot(); + let edits: Vec<_> = prediction + .edits + .iter() + .map(|(range, text)| (range.to_offset(&snapshot), text.clone())) + .collect(); + + assert_eq!(edits, vec![(5..5, " world".into())]); + }); +} + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 4486cde22c3429568bf29f152d0f5f2ded59e8f4..a7da51173eefbcdb9e014f7dcca917e6ebebebf5 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -1,10 +1,11 @@ -use crate::udiff::DiffLine; use anyhow::{Context as _, Result}; use serde::{Deserialize, Serialize}; use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc}; use telemetry_events::EditPredictionRating; -pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]"; +pub use zeta_prompt::udiff::{ + CURSOR_POSITION_MARKER, encode_cursor_in_patch, extract_cursor_from_patch, +}; pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>"; /// Maximum cursor file size to capture (64KB). @@ -12,64 +13,6 @@ pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>"; /// falling back to git-based loading. pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024; -/// Encodes a cursor position into a diff patch by adding a comment line with a caret -/// pointing to the cursor column. -/// -/// The cursor offset is relative to the start of the new text content (additions and context lines). -/// Returns the patch with cursor marker comment lines inserted after the relevant addition line. -pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option) -> String { - let Some(cursor_offset) = cursor_offset else { - return patch.to_string(); - }; - - let mut result = String::new(); - let mut line_start_offset = 0usize; - - for line in patch.lines() { - if matches!( - DiffLine::parse(line), - DiffLine::Garbage(content) - if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER) - ) { - continue; - } - - if !result.is_empty() { - result.push('\n'); - } - result.push_str(line); - - match DiffLine::parse(line) { - DiffLine::Addition(content) => { - let line_end_offset = line_start_offset + content.len(); - - if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset { - let cursor_column = cursor_offset - line_start_offset; - - result.push('\n'); - result.push('#'); - for _ in 0..cursor_column { - result.push(' '); - } - write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap(); - } - - line_start_offset = line_end_offset + 1; - } - DiffLine::Context(content) => { - line_start_offset += content.len() + 1; - } - _ => {} - } - } - - if patch.ends_with('\n') { - result.push('\n'); - } - - result -} - #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] pub struct ExampleSpec { #[serde(default)] @@ -509,53 +452,7 @@ impl ExampleSpec { pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option)> { self.expected_patches .iter() - .map(|patch| { - let mut clean_patch = String::new(); - let mut cursor_offset: Option = None; - let mut line_start_offset = 0usize; - let mut prev_line_start_offset = 0usize; - - for line in patch.lines() { - let diff_line = DiffLine::parse(line); - - match &diff_line { - DiffLine::Garbage(content) - if content.starts_with('#') - && content.contains(CURSOR_POSITION_MARKER) => - { - let caret_column = if let Some(caret_pos) = content.find('^') { - caret_pos - } else if let Some(_) = content.find('<') { - 0 - } else { - continue; - }; - let cursor_column = caret_column.saturating_sub('#'.len_utf8()); - cursor_offset = Some(prev_line_start_offset + cursor_column); - } - _ => { - if !clean_patch.is_empty() { - clean_patch.push('\n'); - } - clean_patch.push_str(line); - - match diff_line { - DiffLine::Addition(content) | DiffLine::Context(content) => { - prev_line_start_offset = line_start_offset; - line_start_offset += content.len() + 1; - } - _ => {} - } - } - } - } - - if patch.ends_with('\n') && !clean_patch.is_empty() { - clean_patch.push('\n'); - } - - (clean_patch, cursor_offset) - }) + .map(|patch| extract_cursor_from_patch(patch)) .collect() } diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index fdfe3ebcf06c8319f5ce00066fa279d79eda7eea..b4556e58b9247624e2d4caeddb5614ff5000d854 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -24,8 +24,9 @@ use zeta_prompt::{ParsedOutput, ZetaPromptInput}; use std::{env, ops::Range, path::Path, sync::Arc}; use zeta_prompt::{ - CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output, - prompt_input_contains_special_tokens, stop_tokens_for_format, + ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output, + parsed_output_from_editable_region, prompt_input_contains_special_tokens, + stop_tokens_for_format, zeta1::{self, EDITABLE_REGION_END_MARKER}, }; @@ -181,6 +182,7 @@ pub fn request_prediction_with_zeta( let parsed_output = output_text.map(|text| ParsedOutput { new_editable_region: text, range_in_excerpt: editable_range_in_excerpt, + cursor_offset_in_new_editable_region: None, }); (request_id, parsed_output, None, None) @@ -283,10 +285,10 @@ pub fn request_prediction_with_zeta( let request_id = EditPredictionId(response.request_id.into()); let output_text = Some(response.output).filter(|s| !s.is_empty()); let model_version = response.model_version; - let parsed_output = ParsedOutput { - new_editable_region: output_text.unwrap_or_default(), - range_in_excerpt: response.editable_range, - }; + let parsed_output = parsed_output_from_editable_region( + response.editable_range, + output_text.unwrap_or_default(), + ); Some((request_id, Some(parsed_output), model_version, usage)) }) @@ -299,6 +301,7 @@ pub fn request_prediction_with_zeta( let Some(ParsedOutput { new_editable_region: mut output_text, range_in_excerpt: editable_range_in_excerpt, + cursor_offset_in_new_editable_region: cursor_offset_in_output, }) = output else { return Ok((Some((request_id, None)), None)); @@ -312,13 +315,6 @@ pub fn request_prediction_with_zeta( .text_for_range(editable_range_in_buffer.clone()) .collect::(); - // Client-side cursor marker processing (applies to both raw and v3 responses) - let cursor_offset_in_output = output_text.find(CURSOR_MARKER); - if let Some(offset) = cursor_offset_in_output { - log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}"); - output_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); - } - if let Some(debug_tx) = &debug_tx { debug_tx .unbounded_send(DebugEvent::EditPredictionFinished( diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index 2b41384e176ac7a6cc5c3dc7f93ddbba3cf027ae..fc85afa371a4edfe8080d602000c38ecedb98c86 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -5,8 +5,7 @@ use crate::{ repair, }; use anyhow::{Context as _, Result}; -use edit_prediction::example_spec::encode_cursor_in_patch; -use zeta_prompt::{CURSOR_MARKER, ZetaFormat, parse_zeta2_model_output}; +use zeta_prompt::{ZetaFormat, parse_zeta2_model_output, parsed_output_to_patch}; pub fn run_parse_output(example: &mut Example) -> Result<()> { example @@ -65,46 +64,18 @@ fn parse_zeta2_output( .context("prompt_inputs required")?; let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?; - let range_in_excerpt = parsed.range_in_excerpt; - + let range_in_excerpt = parsed.range_in_excerpt.clone(); let excerpt = prompt_inputs.cursor_excerpt.as_ref(); - let old_text = excerpt[range_in_excerpt.clone()].to_string(); - let mut new_text = parsed.new_editable_region; - - let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) { - new_text.replace_range(offset..offset + CURSOR_MARKER.len(), ""); - Some(offset) - } else { - None - }; + let editable_region_offset = range_in_excerpt.start; + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); - // Normalize trailing newlines for diff generation - let mut old_text_normalized = old_text; + let mut new_text = parsed.new_editable_region.clone(); if !new_text.is_empty() && !new_text.ends_with('\n') { new_text.push('\n'); } - if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') { - old_text_normalized.push('\n'); - } - - let editable_region_offset = range_in_excerpt.start; - let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); - let editable_region_lines = old_text_normalized.lines().count() as u32; - - let diff = language::unified_diff_with_context( - &old_text_normalized, - &new_text, - editable_region_start_line as u32, - editable_region_start_line as u32, - editable_region_lines, - ); - - let formatted_diff = format!( - "--- a/{path}\n+++ b/{path}\n{diff}", - path = example.spec.cursor_path.to_string_lossy(), - ); - let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset); + let cursor_offset = parsed.cursor_offset_in_new_editable_region; + let formatted_diff = parsed_output_to_patch(prompt_inputs, parsed)?; let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| { ActualCursor::from_editable_region( diff --git a/crates/zeta_prompt/Cargo.toml b/crates/zeta_prompt/Cargo.toml index 21634583d33e13cd9570041f3e8466d05cef9944..8acd91a7a43613fd63f4f46ab73e9485fd64e7d2 100644 --- a/crates/zeta_prompt/Cargo.toml +++ b/crates/zeta_prompt/Cargo.toml @@ -13,6 +13,7 @@ path = "src/zeta_prompt.rs" [dependencies] anyhow.workspace = true +imara-diff.workspace = true serde.workspace = true strum.workspace = true diff --git a/crates/zeta_prompt/src/udiff.rs b/crates/zeta_prompt/src/udiff.rs index 2658da5893ee923dc0f5798554276f5735abb51a..ab0837b9f54ac0bf9ef74038f0c876b751f70200 100644 --- a/crates/zeta_prompt/src/udiff.rs +++ b/crates/zeta_prompt/src/udiff.rs @@ -6,6 +6,10 @@ use std::{ }; use anyhow::{Context as _, Result, anyhow}; +use imara_diff::{ + Algorithm, Sink, diff, + intern::{InternedInput, Interner, Token}, +}; pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> { if prefix.is_empty() { @@ -221,6 +225,181 @@ pub fn disambiguate_by_line_number( } } +pub fn unified_diff_with_context( + old_text: &str, + new_text: &str, + old_start_line: u32, + new_start_line: u32, + context_lines: u32, +) -> String { + let input = InternedInput::new(old_text, new_text); + diff( + Algorithm::Histogram, + &input, + OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line, context_lines), + ) +} + +struct OffsetUnifiedDiffBuilder<'a> { + before: &'a [Token], + after: &'a [Token], + interner: &'a Interner<&'a str>, + pos: u32, + before_hunk_start: u32, + after_hunk_start: u32, + before_hunk_len: u32, + after_hunk_len: u32, + old_line_offset: u32, + new_line_offset: u32, + context_lines: u32, + buffer: String, + dst: String, +} + +impl<'a> OffsetUnifiedDiffBuilder<'a> { + fn new( + input: &'a InternedInput<&'a str>, + old_line_offset: u32, + new_line_offset: u32, + context_lines: u32, + ) -> Self { + Self { + before_hunk_start: 0, + after_hunk_start: 0, + before_hunk_len: 0, + after_hunk_len: 0, + old_line_offset, + new_line_offset, + context_lines, + buffer: String::with_capacity(8), + dst: String::new(), + interner: &input.interner, + before: &input.before, + after: &input.after, + pos: 0, + } + } + + fn print_tokens(&mut self, tokens: &[Token], prefix: char) { + for &token in tokens { + writeln!(&mut self.buffer, "{prefix}{}", self.interner[token]).unwrap(); + } + } + + fn flush(&mut self) { + if self.before_hunk_len == 0 && self.after_hunk_len == 0 { + return; + } + + let end = (self.pos + self.context_lines).min(self.before.len() as u32); + self.update_pos(end, end); + + writeln!( + &mut self.dst, + "@@ -{},{} +{},{} @@", + self.before_hunk_start + 1 + self.old_line_offset, + self.before_hunk_len, + self.after_hunk_start + 1 + self.new_line_offset, + self.after_hunk_len, + ) + .unwrap(); + write!(&mut self.dst, "{}", &self.buffer).unwrap(); + self.buffer.clear(); + self.before_hunk_len = 0; + self.after_hunk_len = 0; + } + + fn update_pos(&mut self, print_to: u32, move_to: u32) { + self.print_tokens(&self.before[self.pos as usize..print_to as usize], ' '); + let len = print_to - self.pos; + self.before_hunk_len += len; + self.after_hunk_len += len; + self.pos = move_to; + } +} + +impl Sink for OffsetUnifiedDiffBuilder<'_> { + type Out = String; + + fn process_change(&mut self, before: Range, after: Range) { + if before.start - self.pos > self.context_lines * 2 { + self.flush(); + } + if self.before_hunk_len == 0 && self.after_hunk_len == 0 { + self.pos = before.start.saturating_sub(self.context_lines); + self.before_hunk_start = self.pos; + self.after_hunk_start = after.start.saturating_sub(self.context_lines); + } + + self.update_pos(before.start, before.end); + self.before_hunk_len += before.end - before.start; + self.after_hunk_len += after.end - after.start; + self.print_tokens( + &self.before[before.start as usize..before.end as usize], + '-', + ); + self.print_tokens(&self.after[after.start as usize..after.end as usize], '+'); + } + + fn finish(mut self) -> Self::Out { + self.flush(); + self.dst + } +} + +pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option) -> String { + let Some(cursor_offset) = cursor_offset else { + return patch.to_string(); + }; + + let mut result = String::new(); + let mut line_start_offset = 0usize; + + for line in patch.lines() { + if matches!( + DiffLine::parse(line), + DiffLine::Garbage(content) + if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER) + ) { + continue; + } + + if !result.is_empty() { + result.push('\n'); + } + result.push_str(line); + + match DiffLine::parse(line) { + DiffLine::Addition(content) => { + let line_end_offset = line_start_offset + content.len(); + + if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset { + let cursor_column = cursor_offset - line_start_offset; + + result.push('\n'); + result.push('#'); + for _ in 0..cursor_column { + result.push(' '); + } + write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap(); + } + + line_start_offset = line_end_offset + 1; + } + DiffLine::Context(content) => { + line_start_offset += content.len() + 1; + } + _ => {} + } + } + + if patch.ends_with('\n') { + result.push('\n'); + } + + result +} + pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { apply_diff_to_string_with_hunk_offset(diff_str, text).map(|(text, _)| text) } @@ -1203,4 +1382,25 @@ mod tests { // Edit range end should be clamped to 7 (new context length). assert_eq!(hunk.edits[0].range, 4..7); } + + #[test] + fn test_unified_diff_with_context_matches_expected_context_window() { + let old_text = "line1\nline2\nline3\nline4\nline5\nCHANGE_ME\nline7\nline8\n"; + let new_text = "line1\nline2\nline3\nline4\nline5\nCHANGED\nline7\nline8\n"; + + let diff_default = unified_diff_with_context(old_text, new_text, 0, 0, 3); + assert_eq!( + diff_default, + "@@ -3,6 +3,6 @@\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n" + ); + + let diff_full_context = unified_diff_with_context(old_text, new_text, 0, 0, 8); + assert_eq!( + diff_full_context, + "@@ -1,8 +1,8 @@\n line1\n line2\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n" + ); + + let diff_no_context = unified_diff_with_context(old_text, new_text, 0, 0, 0); + assert_eq!(diff_no_context, "@@ -6,1 +6,1 @@\n-CHANGE_ME\n+CHANGED\n"); + } } diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 0d72d6cd7a46782aa4b572a4ef564d5fe3dec417..49b86404a8ad49c27e29bb2b887fb3fc8171c35c 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -106,10 +106,19 @@ impl std::fmt::Display for ZetaFormat { impl ZetaFormat { pub fn parse(format_name: &str) -> Result { + let lower = format_name.to_lowercase(); + + // Exact case-insensitive match takes priority, bypassing ambiguity checks. + for variant in ZetaFormat::iter() { + if <&'static str>::from(&variant).to_lowercase() == lower { + return Ok(variant); + } + } + let mut results = ZetaFormat::iter().filter(|version| { <&'static str>::from(version) .to_lowercase() - .contains(&format_name.to_lowercase()) + .contains(&lower) }); let Some(result) = results.next() else { anyhow::bail!( @@ -927,11 +936,39 @@ fn cursor_in_new_text( }) } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct ParsedOutput { /// Text that should replace the editable region pub new_editable_region: String, /// The byte range within `cursor_excerpt` that this replacement applies to pub range_in_excerpt: Range, + /// Byte offset of the cursor marker within `new_editable_region`, if present + pub cursor_offset_in_new_editable_region: Option, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct CursorPosition { + pub path: String, + pub row: usize, + pub column: usize, + pub offset: usize, + pub editable_region_offset: usize, +} + +pub fn parsed_output_from_editable_region( + range_in_excerpt: Range, + mut new_editable_region: String, +) -> ParsedOutput { + let cursor_offset_in_new_editable_region = new_editable_region.find(CURSOR_MARKER); + if let Some(offset) = cursor_offset_in_new_editable_region { + new_editable_region.replace_range(offset..offset + CURSOR_MARKER.len(), ""); + } + + ParsedOutput { + new_editable_region, + range_in_excerpt, + cursor_offset_in_new_editable_region, + } } /// Parse model output for the given zeta format @@ -999,12 +1036,97 @@ pub fn parse_zeta2_model_output( let range_in_excerpt = range_in_context.start + context_start..range_in_context.end + context_start; - Ok(ParsedOutput { - new_editable_region: output, - range_in_excerpt, + Ok(parsed_output_from_editable_region(range_in_excerpt, output)) +} + +pub fn parse_zeta2_model_output_as_patch( + output: &str, + format: ZetaFormat, + prompt_inputs: &ZetaPromptInput, +) -> Result { + let parsed = parse_zeta2_model_output(output, format, prompt_inputs)?; + parsed_output_to_patch(prompt_inputs, parsed) +} + +pub fn cursor_position_from_parsed_output( + prompt_inputs: &ZetaPromptInput, + parsed: &ParsedOutput, +) -> Option { + let cursor_offset = parsed.cursor_offset_in_new_editable_region?; + let editable_region_offset = parsed.range_in_excerpt.start; + let excerpt = prompt_inputs.cursor_excerpt.as_ref(); + + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); + + let new_editable_region = &parsed.new_editable_region; + let prefix_end = cursor_offset.min(new_editable_region.len()); + let new_region_prefix = &new_editable_region[..prefix_end]; + + let row = editable_region_start_line + new_region_prefix.matches('\n').count(); + + let column = match new_region_prefix.rfind('\n') { + Some(last_newline) => cursor_offset - last_newline - 1, + None => { + let content_prefix = &excerpt[..editable_region_offset]; + let content_column = match content_prefix.rfind('\n') { + Some(last_newline) => editable_region_offset - last_newline - 1, + None => editable_region_offset, + }; + content_column + cursor_offset + } + }; + + Some(CursorPosition { + path: prompt_inputs.cursor_path.to_string_lossy().into_owned(), + row, + column, + offset: editable_region_offset + cursor_offset, + editable_region_offset: cursor_offset, }) } +pub fn parsed_output_to_patch( + prompt_inputs: &ZetaPromptInput, + parsed: ParsedOutput, +) -> Result { + let range_in_excerpt = parsed.range_in_excerpt; + let excerpt = prompt_inputs.cursor_excerpt.as_ref(); + let old_text = excerpt[range_in_excerpt.clone()].to_string(); + let mut new_text = parsed.new_editable_region; + + let mut old_text_normalized = old_text; + if !new_text.is_empty() && !new_text.ends_with('\n') { + new_text.push('\n'); + } + if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') { + old_text_normalized.push('\n'); + } + + let editable_region_offset = range_in_excerpt.start; + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count() as u32; + let editable_region_lines = old_text_normalized.lines().count() as u32; + + let diff = udiff::unified_diff_with_context( + &old_text_normalized, + &new_text, + editable_region_start_line, + editable_region_start_line, + editable_region_lines, + ); + + let path = prompt_inputs + .cursor_path + .to_string_lossy() + .trim_start_matches('/') + .to_string(); + let formatted_diff = format!("--- a/{path}\n+++ b/{path}\n{diff}"); + + Ok(udiff::encode_cursor_in_patch( + &formatted_diff, + parsed.cursor_offset_in_new_editable_region, + )) +} + pub fn excerpt_range_for_format( format: ZetaFormat, ranges: &ExcerptRanges, @@ -5400,6 +5522,33 @@ mod tests { assert_eq!(apply_edit(excerpt, &output1), "new content\n"); } + #[test] + fn test_parsed_output_to_patch_round_trips_through_udiff_application() { + let excerpt = "before ctx\nctx start\neditable old\nctx end\nafter ctx\n"; + let context_start = excerpt.find("ctx start").unwrap(); + let context_end = excerpt.find("after ctx").unwrap(); + let editable_start = excerpt.find("editable old").unwrap(); + let editable_end = editable_start + "editable old\n".len(); + let input = make_input_with_context_range( + excerpt, + editable_start..editable_end, + context_start..context_end, + editable_start, + ); + + let parsed = parse_zeta2_model_output( + "editable new\n>>>>>>> UPDATED\n", + ZetaFormat::V0131GitMergeMarkersPrefix, + &input, + ) + .unwrap(); + let expected = apply_edit(excerpt, &parsed); + let patch = parsed_output_to_patch(&input, parsed).unwrap(); + let patched = udiff::apply_diff_to_string(&patch, excerpt).unwrap(); + + assert_eq!(patched, expected); + } + #[test] fn test_special_tokens_not_triggered_by_comment_separator() { // Regression test for https://github.com/zed-industries/zed/issues/52489