Cargo.lock 🔗
@@ -22542,6 +22542,7 @@ name = "zeta_prompt"
version = "0.1.0"
dependencies = [
"anyhow",
+ "imara-diff",
"indoc",
"serde",
"strum 0.27.2",
Ben Kunkle created
Self-Review Checklist:
- [x] I've reviewed my own diff for quality, security, and reliability
- [ ] Unsafe blocks (if any) have justifying comments
- [ ] The content is consistent with the [UI/UX
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable
Closes #ISSUE
Release Notes:
- N/A
Cargo.lock | 1
crates/edit_prediction/src/edit_prediction_tests.rs | 59 ++++
crates/edit_prediction/src/example_spec.rs | 111 --------
crates/edit_prediction/src/zeta.rs | 22
crates/edit_prediction_cli/src/parse_output.rs | 43 --
crates/zeta_prompt/Cargo.toml | 1
crates/zeta_prompt/src/udiff.rs | 200 +++++++++++++++
crates/zeta_prompt/src/zeta_prompt.rs | 157 +++++++++++
8 files changed, 434 insertions(+), 160 deletions(-)
@@ -22542,6 +22542,7 @@ name = "zeta_prompt"
version = "0.1.0"
dependencies = [
"anyhow",
+ "imara-diff",
"indoc",
"serde",
"strum 0.27.2",
@@ -2707,6 +2707,65 @@ async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppConte
});
}
+#[gpui::test]
+async fn test_v3_prediction_strips_cursor_marker_from_edit_text(cx: &mut TestAppContext) { // a `<|user_cursor|>` marker in a V3 response must not leak into the edit text
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.txt": "hello"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project
+ .find_project_path(path!("root/foo.txt"), cx)
+ .unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(0, 5)); // row 0, column 5: just past "hello"
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (request, respond_tx) = requests.predict.next().await.unwrap(); // intercept the prediction request and answer it by hand
+ let excerpt_length = request.input.cursor_excerpt.len();
+ respond_tx
+ .send(PredictEditsV3Response {
+ request_id: Uuid::new_v4().to_string(),
+ output: "hello<|user_cursor|> world".to_string(), // marker sits between the existing text and the insertion
+ editable_range: 0..excerpt_length, // the whole excerpt is editable
+ model_version: None,
+ })
+ .unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.update(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .prediction_at(&buffer, None, &project, cx)
+ .expect("should have prediction");
+ let snapshot = buffer.read(cx).snapshot();
+ let edits: Vec<_> = prediction
+ .edits
+ .iter()
+ .map(|(range, text)| (range.to_offset(&snapshot), text.clone())) // compare as plain offsets
+ .collect();
+
+ assert_eq!(edits, vec![(5..5, " world".into())]); // a single insertion after "hello", with no marker text in it
+ });
+}
+
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
@@ -1,10 +1,11 @@
-use crate::udiff::DiffLine;
use anyhow::{Context as _, Result};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
use telemetry_events::EditPredictionRating;
-pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
+pub use zeta_prompt::udiff::{
+ CURSOR_POSITION_MARKER, encode_cursor_in_patch, extract_cursor_from_patch,
+};
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
/// Maximum cursor file size to capture (64KB).
@@ -12,64 +13,6 @@ pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
/// falling back to git-based loading.
pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
-/// Encodes a cursor position into a diff patch by adding a comment line with a caret
-/// pointing to the cursor column.
-///
-/// The cursor offset is relative to the start of the new text content (additions and context lines).
-/// Returns the patch with cursor marker comment lines inserted after the relevant addition line.
-pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option<usize>) -> String {
- let Some(cursor_offset) = cursor_offset else {
- return patch.to_string();
- };
-
- let mut result = String::new();
- let mut line_start_offset = 0usize;
-
- for line in patch.lines() {
- if matches!(
- DiffLine::parse(line),
- DiffLine::Garbage(content)
- if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER)
- ) {
- continue;
- }
-
- if !result.is_empty() {
- result.push('\n');
- }
- result.push_str(line);
-
- match DiffLine::parse(line) {
- DiffLine::Addition(content) => {
- let line_end_offset = line_start_offset + content.len();
-
- if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset {
- let cursor_column = cursor_offset - line_start_offset;
-
- result.push('\n');
- result.push('#');
- for _ in 0..cursor_column {
- result.push(' ');
- }
- write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
- }
-
- line_start_offset = line_end_offset + 1;
- }
- DiffLine::Context(content) => {
- line_start_offset += content.len() + 1;
- }
- _ => {}
- }
- }
-
- if patch.ends_with('\n') {
- result.push('\n');
- }
-
- result
-}
-
#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct ExampleSpec {
#[serde(default)]
@@ -509,53 +452,7 @@ impl ExampleSpec {
pub fn expected_patches_with_cursor_positions(&self) -> Vec<(String, Option<usize>)> {
self.expected_patches
.iter()
- .map(|patch| {
- let mut clean_patch = String::new();
- let mut cursor_offset: Option<usize> = None;
- let mut line_start_offset = 0usize;
- let mut prev_line_start_offset = 0usize;
-
- for line in patch.lines() {
- let diff_line = DiffLine::parse(line);
-
- match &diff_line {
- DiffLine::Garbage(content)
- if content.starts_with('#')
- && content.contains(CURSOR_POSITION_MARKER) =>
- {
- let caret_column = if let Some(caret_pos) = content.find('^') {
- caret_pos
- } else if let Some(_) = content.find('<') {
- 0
- } else {
- continue;
- };
- let cursor_column = caret_column.saturating_sub('#'.len_utf8());
- cursor_offset = Some(prev_line_start_offset + cursor_column);
- }
- _ => {
- if !clean_patch.is_empty() {
- clean_patch.push('\n');
- }
- clean_patch.push_str(line);
-
- match diff_line {
- DiffLine::Addition(content) | DiffLine::Context(content) => {
- prev_line_start_offset = line_start_offset;
- line_start_offset += content.len() + 1;
- }
- _ => {}
- }
- }
- }
- }
-
- if patch.ends_with('\n') && !clean_patch.is_empty() {
- clean_patch.push('\n');
- }
-
- (clean_patch, cursor_offset)
- })
+ .map(|patch| extract_cursor_from_patch(patch))
.collect()
}
@@ -24,8 +24,9 @@ use zeta_prompt::{ParsedOutput, ZetaPromptInput};
use std::{env, ops::Range, path::Path, sync::Arc};
use zeta_prompt::{
- CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
- prompt_input_contains_special_tokens, stop_tokens_for_format,
+ ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
+ parsed_output_from_editable_region, prompt_input_contains_special_tokens,
+ stop_tokens_for_format,
zeta1::{self, EDITABLE_REGION_END_MARKER},
};
@@ -181,6 +182,7 @@ pub fn request_prediction_with_zeta(
let parsed_output = output_text.map(|text| ParsedOutput {
new_editable_region: text,
range_in_excerpt: editable_range_in_excerpt,
+ cursor_offset_in_new_editable_region: None,
});
(request_id, parsed_output, None, None)
@@ -283,10 +285,10 @@ pub fn request_prediction_with_zeta(
let request_id = EditPredictionId(response.request_id.into());
let output_text = Some(response.output).filter(|s| !s.is_empty());
let model_version = response.model_version;
- let parsed_output = ParsedOutput {
- new_editable_region: output_text.unwrap_or_default(),
- range_in_excerpt: response.editable_range,
- };
+ let parsed_output = parsed_output_from_editable_region(
+ response.editable_range,
+ output_text.unwrap_or_default(),
+ );
Some((request_id, Some(parsed_output), model_version, usage))
})
@@ -299,6 +301,7 @@ pub fn request_prediction_with_zeta(
let Some(ParsedOutput {
new_editable_region: mut output_text,
range_in_excerpt: editable_range_in_excerpt,
+ cursor_offset_in_new_editable_region: cursor_offset_in_output,
}) = output
else {
return Ok((Some((request_id, None)), None));
@@ -312,13 +315,6 @@ pub fn request_prediction_with_zeta(
.text_for_range(editable_range_in_buffer.clone())
.collect::<String>();
- // Client-side cursor marker processing (applies to both raw and v3 responses)
- let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
- if let Some(offset) = cursor_offset_in_output {
- log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
- output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
- }
-
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
@@ -5,8 +5,7 @@ use crate::{
repair,
};
use anyhow::{Context as _, Result};
-use edit_prediction::example_spec::encode_cursor_in_patch;
-use zeta_prompt::{CURSOR_MARKER, ZetaFormat, parse_zeta2_model_output};
+use zeta_prompt::{ZetaFormat, parse_zeta2_model_output, parsed_output_to_patch};
pub fn run_parse_output(example: &mut Example) -> Result<()> {
example
@@ -65,46 +64,18 @@ fn parse_zeta2_output(
.context("prompt_inputs required")?;
let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?;
- let range_in_excerpt = parsed.range_in_excerpt;
-
+ let range_in_excerpt = parsed.range_in_excerpt.clone();
let excerpt = prompt_inputs.cursor_excerpt.as_ref();
- let old_text = excerpt[range_in_excerpt.clone()].to_string();
- let mut new_text = parsed.new_editable_region;
-
- let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
- new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
- Some(offset)
- } else {
- None
- };
+ let editable_region_offset = range_in_excerpt.start;
+ let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
- // Normalize trailing newlines for diff generation
- let mut old_text_normalized = old_text;
+ let mut new_text = parsed.new_editable_region.clone();
if !new_text.is_empty() && !new_text.ends_with('\n') {
new_text.push('\n');
}
- if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
- old_text_normalized.push('\n');
- }
-
- let editable_region_offset = range_in_excerpt.start;
- let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
- let editable_region_lines = old_text_normalized.lines().count() as u32;
-
- let diff = language::unified_diff_with_context(
- &old_text_normalized,
- &new_text,
- editable_region_start_line as u32,
- editable_region_start_line as u32,
- editable_region_lines,
- );
-
- let formatted_diff = format!(
- "--- a/{path}\n+++ b/{path}\n{diff}",
- path = example.spec.cursor_path.to_string_lossy(),
- );
- let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
+ let cursor_offset = parsed.cursor_offset_in_new_editable_region;
+ let formatted_diff = parsed_output_to_patch(prompt_inputs, parsed)?;
let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
ActualCursor::from_editable_region(
@@ -13,6 +13,7 @@ path = "src/zeta_prompt.rs"
[dependencies]
anyhow.workspace = true
+imara-diff.workspace = true
serde.workspace = true
strum.workspace = true
@@ -6,6 +6,10 @@ use std::{
};
use anyhow::{Context as _, Result, anyhow};
+use imara_diff::{
+ Algorithm, Sink, diff,
+ intern::{InternedInput, Interner, Token},
+};
pub fn strip_diff_path_prefix<'a>(diff: &'a str, prefix: &str) -> Cow<'a, str> {
if prefix.is_empty() {
@@ -221,6 +225,181 @@ pub fn disambiguate_by_line_number(
}
}
+pub fn unified_diff_with_context( // diffs `old_text` against `new_text` and returns the hunks rendered as unified-diff text
+ old_text: &str,
+ new_text: &str,
+ old_start_line: u32, // offset added to the old-side line numbers in hunk headers
+ new_start_line: u32, // offset added to the new-side line numbers in hunk headers
+ context_lines: u32, // number of unchanged lines kept around each change
+) -> String {
+ let input = InternedInput::new(old_text, new_text);
+ diff(
+ Algorithm::Histogram,
+ &input,
+ OffsetUnifiedDiffBuilder::new(&input, old_start_line, new_start_line, context_lines), // the sink that formats the hunks
+ )
+}
+
+struct OffsetUnifiedDiffBuilder<'a> { // diff sink that renders unified-diff hunks with shifted header line numbers
+ before: &'a [Token], // tokens of the old text (`input.before`)
+ after: &'a [Token], // tokens of the new text (`input.after`)
+ interner: &'a Interner<&'a str>, // maps a token back to its text when printing
+ pos: u32, // index into `before` up to which output has been produced
+ before_hunk_start: u32, // 0-based index in `before` where the open hunk starts
+ after_hunk_start: u32, // 0-based index in `after` where the open hunk starts
+ before_hunk_len: u32, // old-side line count of the open hunk (0 on both sides = no open hunk)
+ after_hunk_len: u32, // new-side line count of the open hunk
+ old_line_offset: u32, // added to old-side header line numbers
+ new_line_offset: u32, // added to new-side header line numbers
+ context_lines: u32, // unchanged lines kept before and after each change
+ buffer: String, // body of the hunk currently being built
+ dst: String, // finished output: all flushed hunks
+}
+
+impl<'a> OffsetUnifiedDiffBuilder<'a> {
+ fn new( // builds a sink with no open hunk, positioned at the first old-side token
+ input: &'a InternedInput<&'a str>,
+ old_line_offset: u32,
+ new_line_offset: u32,
+ context_lines: u32,
+ ) -> Self {
+ Self {
+ before_hunk_start: 0,
+ after_hunk_start: 0,
+ before_hunk_len: 0,
+ after_hunk_len: 0,
+ old_line_offset,
+ new_line_offset,
+ context_lines,
+ buffer: String::with_capacity(8),
+ dst: String::new(),
+ interner: &input.interner,
+ before: &input.before,
+ after: &input.after,
+ pos: 0,
+ }
+ }
+
+ fn print_tokens(&mut self, tokens: &[Token], prefix: char) { // appends each token to the pending hunk body as its own `prefix`-marked line
+ for &token in tokens {
+ writeln!(&mut self.buffer, "{prefix}{}", self.interner[token]).unwrap();
+ }
+ }
+
+ fn flush(&mut self) { // moves the open hunk (header + body) into `dst`
+ if self.before_hunk_len == 0 && self.after_hunk_len == 0 {
+ return; // nothing pending
+ }
+
+ let end = (self.pos + self.context_lines).min(self.before.len() as u32); // trailing context, clamped to the end of the old text
+ self.update_pos(end, end);
+
+ writeln!(
+ &mut self.dst,
+ "@@ -{},{} +{},{} @@",
+ self.before_hunk_start + 1 + self.old_line_offset, // 1-based, then shifted by the caller's offset
+ self.before_hunk_len,
+ self.after_hunk_start + 1 + self.new_line_offset,
+ self.after_hunk_len,
+ )
+ .unwrap();
+ write!(&mut self.dst, "{}", &self.buffer).unwrap();
+ self.buffer.clear(); // reset so the next change opens a fresh hunk
+ self.before_hunk_len = 0;
+ self.after_hunk_len = 0;
+ }
+
+ fn update_pos(&mut self, print_to: u32, move_to: u32) { // prints `before[pos..print_to]` as context lines, then sets `pos` to `move_to`
+ self.print_tokens(&self.before[self.pos as usize..print_to as usize], ' ');
+ let len = print_to - self.pos;
+ self.before_hunk_len += len; // context lines count on both sides of the hunk
+ self.after_hunk_len += len;
+ self.pos = move_to;
+ }
+}
+
+impl Sink for OffsetUnifiedDiffBuilder<'_> {
+ type Out = String;
+
+ fn process_change(&mut self, before: Range<u32>, after: Range<u32>) { // adds one change to the output; the ranges index into `self.before` / `self.after`
+ if before.start - self.pos > self.context_lines * 2 { // gap since the last change exceeds two context windows: close the open hunk
+ self.flush();
+ }
+ if self.before_hunk_len == 0 && self.after_hunk_len == 0 { // no hunk open: start one `context_lines` before the change
+ self.pos = before.start.saturating_sub(self.context_lines);
+ self.before_hunk_start = self.pos;
+ self.after_hunk_start = after.start.saturating_sub(self.context_lines);
+ }
+
+ self.update_pos(before.start, before.end); // print unchanged lines up to the change, then skip past the removed lines
+ self.before_hunk_len += before.end - before.start;
+ self.after_hunk_len += after.end - after.start;
+ self.print_tokens(
+ &self.before[before.start as usize..before.end as usize],
+ '-',
+ );
+ self.print_tokens(&self.after[after.start as usize..after.end as usize], '+');
+ }
+
+ fn finish(mut self) -> Self::Out { // flushes the last open hunk and returns the accumulated text
+ self.flush();
+ self.dst
+ }
+}
+
+pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option<usize>) -> String { // marks `cursor_offset` with a `#`-comment caret line placed after the addition line that contains it
+ let Some(cursor_offset) = cursor_offset else {
+ return patch.to_string(); // no cursor: the patch is returned unchanged
+ };
+
+ let mut result = String::new();
+ let mut line_start_offset = 0usize; // offset of the current line within the new text (addition and context lines)
+
+ for line in patch.lines() {
+ if matches!( // drop any cursor comment line already present in the patch
+ DiffLine::parse(line),
+ DiffLine::Garbage(content)
+ if content.starts_with('#') && content.contains(CURSOR_POSITION_MARKER)
+ ) {
+ continue;
+ }
+
+ if !result.is_empty() {
+ result.push('\n');
+ }
+ result.push_str(line);
+
+ match DiffLine::parse(line) {
+ DiffLine::Addition(content) => {
+ let line_end_offset = line_start_offset + content.len();
+
+ if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset { // inclusive end: the cursor may sit at the end of the line
+ let cursor_column = cursor_offset - line_start_offset;
+
+ result.push('\n');
+ result.push('#');
+ for _ in 0..cursor_column { // pad with spaces up to the cursor column
+ result.push(' ');
+ }
+ write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
+ }
+
+ line_start_offset = line_end_offset + 1; // +1 for the line's newline
+ }
+ DiffLine::Context(content) => { // context lines advance the offset but never receive a marker
+ line_start_offset += content.len() + 1;
+ }
+ _ => {}
+ }
+ }
+
+ if patch.ends_with('\n') { // `lines()` drops the final newline; restore it
+ result.push('\n');
+ }
+
+ result
+}
+
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
apply_diff_to_string_with_hunk_offset(diff_str, text).map(|(text, _)| text)
}
@@ -1203,4 +1382,25 @@ mod tests {
// Edit range end should be clamped to 7 (new context length).
assert_eq!(hunk.edits[0].range, 4..7);
}
+
+ #[test]
+ fn test_unified_diff_with_context_matches_expected_context_window() { // one changed line out of eight, diffed with three different context sizes
+ let old_text = "line1\nline2\nline3\nline4\nline5\nCHANGE_ME\nline7\nline8\n";
+ let new_text = "line1\nline2\nline3\nline4\nline5\nCHANGED\nline7\nline8\n";
+
+ let diff_default = unified_diff_with_context(old_text, new_text, 0, 0, 3); // 3 context lines before the change; only 2 lines exist after it
+ assert_eq!(
+ diff_default,
+ "@@ -3,6 +3,6 @@\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n"
+ );
+
+ let diff_full_context = unified_diff_with_context(old_text, new_text, 0, 0, 8); // context larger than the text: the hunk covers every line
+ assert_eq!(
+ diff_full_context,
+ "@@ -1,8 +1,8 @@\n line1\n line2\n line3\n line4\n line5\n-CHANGE_ME\n+CHANGED\n line7\n line8\n"
+ );
+
+ let diff_no_context = unified_diff_with_context(old_text, new_text, 0, 0, 0); // zero context: only the changed line appears
+ assert_eq!(diff_no_context, "@@ -6,1 +6,1 @@\n-CHANGE_ME\n+CHANGED\n");
+ }
}
@@ -106,10 +106,19 @@ impl std::fmt::Display for ZetaFormat {
impl ZetaFormat {
pub fn parse(format_name: &str) -> Result<Self> {
+ let lower = format_name.to_lowercase();
+
+ // Exact case-insensitive match takes priority, bypassing ambiguity checks.
+ for variant in ZetaFormat::iter() {
+ if <&'static str>::from(&variant).to_lowercase() == lower {
+ return Ok(variant);
+ }
+ }
+
let mut results = ZetaFormat::iter().filter(|version| {
<&'static str>::from(version)
.to_lowercase()
- .contains(&format_name.to_lowercase())
+ .contains(&lower)
});
let Some(result) = results.next() else {
anyhow::bail!(
@@ -927,11 +936,39 @@ fn cursor_in_new_text(
})
}
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] // lets a ParsedOutput be cloned, compared, debug-printed and (de)serialized
pub struct ParsedOutput {
/// Text that should replace the editable region
pub new_editable_region: String,
/// The byte range within `cursor_excerpt` that this replacement applies to
pub range_in_excerpt: Range<usize>,
+ /// Byte offset of the cursor marker within `new_editable_region`, if present
+ pub cursor_offset_in_new_editable_region: Option<usize>, // `parsed_output_from_editable_region` records where the marker was before stripping it
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+pub struct CursorPosition { // cursor location as filled in by `cursor_position_from_parsed_output`
+ pub path: String, // lossy string form of the prompt input's `cursor_path`
+ pub row: usize, // line index: newlines before the editable region plus newlines before the cursor in the new text
+ pub column: usize, // byte distance from the start of the cursor's line
+ pub offset: usize, // editable region start in the excerpt + cursor offset within the new editable region
+ pub editable_region_offset: usize, // cursor offset within the new editable region itself
+}
+
+pub fn parsed_output_from_editable_region( // builds a ParsedOutput, stripping the first CURSOR_MARKER from the text and recording its byte offset
+ range_in_excerpt: Range<usize>,
+ mut new_editable_region: String,
+) -> ParsedOutput {
+ let cursor_offset_in_new_editable_region = new_editable_region.find(CURSOR_MARKER); // `None` when the text has no marker
+ if let Some(offset) = cursor_offset_in_new_editable_region {
+ new_editable_region.replace_range(offset..offset + CURSOR_MARKER.len(), ""); // only this first occurrence is removed
+ }
+
+ ParsedOutput {
+ new_editable_region,
+ range_in_excerpt,
+ cursor_offset_in_new_editable_region,
+ }
}
/// Parse model output for the given zeta format
@@ -999,12 +1036,97 @@ pub fn parse_zeta2_model_output(
let range_in_excerpt =
range_in_context.start + context_start..range_in_context.end + context_start;
- Ok(ParsedOutput {
- new_editable_region: output,
- range_in_excerpt,
+ Ok(parsed_output_from_editable_region(range_in_excerpt, output))
+}
+
+pub fn parse_zeta2_model_output_as_patch( // parses raw model output and renders the result as a patch in a single call
+ output: &str,
+ format: ZetaFormat,
+ prompt_inputs: &ZetaPromptInput,
+) -> Result<String> {
+ let parsed = parse_zeta2_model_output(output, format, prompt_inputs)?; // parse errors are propagated to the caller
+ parsed_output_to_patch(prompt_inputs, parsed)
+}
+
+pub fn cursor_position_from_parsed_output( // derives path/row/column/offsets for the cursor recorded in `parsed`, if any
+ prompt_inputs: &ZetaPromptInput,
+ parsed: &ParsedOutput,
+) -> Option<CursorPosition> {
+ let cursor_offset = parsed.cursor_offset_in_new_editable_region?; // no recorded cursor: return None
+ let editable_region_offset = parsed.range_in_excerpt.start;
+ let excerpt = prompt_inputs.cursor_excerpt.as_ref();
+
+ let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); // newlines in the excerpt before the editable region
+
+ let new_editable_region = &parsed.new_editable_region;
+ let prefix_end = cursor_offset.min(new_editable_region.len()); // clamp so the slice below cannot run past the end
+ let new_region_prefix = &new_editable_region[..prefix_end]; // NOTE(review): byte slicing; panics if the offset is not on a char boundary
+
+ let row = editable_region_start_line + new_region_prefix.matches('\n').count();
+
+ let column = match new_region_prefix.rfind('\n') {
+ Some(last_newline) => cursor_offset - last_newline - 1, // NOTE(review): uses the unclamped `cursor_offset`, unlike the prefix above
+ None => { // cursor is on the region's first line: add the column at which the region starts in the excerpt
+ let content_prefix = &excerpt[..editable_region_offset];
+ let content_column = match content_prefix.rfind('\n') {
+ Some(last_newline) => editable_region_offset - last_newline - 1,
+ None => editable_region_offset,
+ };
+ content_column + cursor_offset
+ }
+ };
+
+ Some(CursorPosition {
+ path: prompt_inputs.cursor_path.to_string_lossy().into_owned(),
+ row,
+ column,
+ offset: editable_region_offset + cursor_offset, // region start in the excerpt plus the offset within the new text
+ editable_region_offset: cursor_offset, // despite the name, this is the cursor's offset within the editable region
})
}
+pub fn parsed_output_to_patch( // renders `parsed` as a unified diff against the excerpt's editable region, with the cursor encoded if present
+ prompt_inputs: &ZetaPromptInput,
+ parsed: ParsedOutput,
+) -> Result<String> { // NOTE(review): no fallible call is visible below; an invalid range panics in the slice instead of returning Err
+ let range_in_excerpt = parsed.range_in_excerpt;
+ let excerpt = prompt_inputs.cursor_excerpt.as_ref();
+ let old_text = excerpt[range_in_excerpt.clone()].to_string(); // current content of the editable region
+ let mut new_text = parsed.new_editable_region;
+
+ let mut old_text_normalized = old_text;
+ if !new_text.is_empty() && !new_text.ends_with('\n') { // give non-empty text a trailing newline so both sides end the same way
+ new_text.push('\n');
+ }
+ if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
+ old_text_normalized.push('\n');
+ }
+
+ let editable_region_offset = range_in_excerpt.start;
+ let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count() as u32; // newlines before the region; used as the header line offset
+ let editable_region_lines = old_text_normalized.lines().count() as u32; // passed as the context size, so context can cover the whole region
+
+ let diff = udiff::unified_diff_with_context(
+ &old_text_normalized,
+ &new_text,
+ editable_region_start_line,
+ editable_region_start_line,
+ editable_region_lines,
+ );
+
+ let path = prompt_inputs
+ .cursor_path
+ .to_string_lossy()
+ .trim_start_matches('/') // drop leading slashes before prefixing with `a/` and `b/`
+ .to_string();
+ let formatted_diff = format!("--- a/{path}\n+++ b/{path}\n{diff}");
+
+ Ok(udiff::encode_cursor_in_patch(
+ &formatted_diff,
+ parsed.cursor_offset_in_new_editable_region, // `None` leaves the patch unchanged
+ ))
+}
+
pub fn excerpt_range_for_format(
format: ZetaFormat,
ranges: &ExcerptRanges,
@@ -5400,6 +5522,33 @@ mod tests {
assert_eq!(apply_edit(excerpt, &output1), "new content\n");
}
+ #[test]
+ fn test_parsed_output_to_patch_round_trips_through_udiff_application() { // applying the generated patch must give the same text as applying the edit directly
+ let excerpt = "before ctx\nctx start\neditable old\nctx end\nafter ctx\n";
+ let context_start = excerpt.find("ctx start").unwrap();
+ let context_end = excerpt.find("after ctx").unwrap();
+ let editable_start = excerpt.find("editable old").unwrap();
+ let editable_end = editable_start + "editable old\n".len(); // editable region is exactly the "editable old" line, newline included
+ let input = make_input_with_context_range(
+ excerpt,
+ editable_start..editable_end,
+ context_start..context_end,
+ editable_start,
+ );
+
+ let parsed = parse_zeta2_model_output(
+ "editable new\n>>>>>>> UPDATED\n",
+ ZetaFormat::V0131GitMergeMarkersPrefix,
+ &input,
+ )
+ .unwrap();
+ let expected = apply_edit(excerpt, &parsed); // computed first: `parsed` is moved into the call below
+ let patch = parsed_output_to_patch(&input, parsed).unwrap();
+ let patched = udiff::apply_diff_to_string(&patch, excerpt).unwrap(); // the patch is applied to the full excerpt, not just the region
+
+ assert_eq!(patched, expected);
+ }
+
#[test]
fn test_special_tokens_not_triggered_by_comment_separator() {
// Regression test for https://github.com/zed-industries/zed/issues/52489