ep: Unify code in `parse_output` and `zeta_prompt` (#50958)

Oleksiy Syvokon created

This also fixed `ep parse-output` for newer formats

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/zeta.rs             | 30 +++++++------
crates/edit_prediction_cli/src/parse_output.rs | 37 +++++------------
crates/zeta_prompt/src/zeta_prompt.rs          | 41 ++++++++++++-------
3 files changed, 53 insertions(+), 55 deletions(-)

Detailed changes

crates/edit_prediction/src/zeta.rs 🔗

@@ -19,7 +19,7 @@ use settings::EditPredictionPromptFormat;
 use text::{Anchor, Bias};
 use ui::SharedString;
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
-use zeta_prompt::ZetaPromptInput;
+use zeta_prompt::{ParsedOutput, ZetaPromptInput};
 
 use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
 use zeta_prompt::{
@@ -175,13 +175,12 @@ pub fn request_prediction_with_zeta(
 
                             let request_id = EditPredictionId(request_id.into());
                             let output_text = zeta1::clean_zeta1_model_output(&response_text);
+                            let parsed_output = output_text.map(|text| ParsedOutput {
+                                new_editable_region: text,
+                                range_in_excerpt: editable_range_in_excerpt,
+                            });
 
-                            (
-                                request_id,
-                                Some(editable_range_in_excerpt).zip(output_text),
-                                None,
-                                None,
-                            )
+                            (request_id, parsed_output, None, None)
                         }
                         EditPredictionPromptFormat::Zeta2 => {
                             let prompt = format_zeta_prompt(&prompt_input, zeta_version);
@@ -271,20 +270,23 @@ pub fn request_prediction_with_zeta(
                     let request_id = EditPredictionId(response.request_id.into());
                     let output_text = Some(response.output).filter(|s| !s.is_empty());
                     let model_version = response.model_version;
+                    let parsed_output = ParsedOutput {
+                        new_editable_region: output_text.unwrap_or_default(),
+                        range_in_excerpt: response.editable_range,
+                    };
 
-                    (
-                        request_id,
-                        Some(response.editable_range).zip(output_text),
-                        model_version,
-                        usage,
-                    )
+                    (request_id, Some(parsed_output), model_version, usage)
                 };
 
             let received_response_at = Instant::now();
 
             log::trace!("Got edit prediction response");
 
-            let Some((editable_range_in_excerpt, mut output_text)) = output else {
+            let Some(ParsedOutput {
+                new_editable_region: mut output_text,
+                range_in_excerpt: editable_range_in_excerpt,
+            }) = output
+            else {
                 return Ok(((request_id, None), None));
             };
 

crates/edit_prediction_cli/src/parse_output.rs 🔗

@@ -6,7 +6,7 @@ use crate::{
 };
 use anyhow::{Context as _, Result};
 use edit_prediction::example_spec::encode_cursor_in_patch;
-use zeta_prompt::{CURSOR_MARKER, ZetaFormat, output_end_marker_for_format, resolve_cursor_region};
+use zeta_prompt::{CURSOR_MARKER, ZetaFormat, parse_zeta2_model_output};
 
 pub fn run_parse_output(example: &mut Example) -> Result<()> {
     example
@@ -60,10 +60,13 @@ fn parse_zeta2_output(
         .as_ref()
         .context("prompt_inputs required")?;
 
-    let (context, editable_range, _, _) = resolve_cursor_region(prompt_inputs, format);
-    let old_text = context[editable_range].to_string();
+    let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?;
+    let range_in_excerpt = parsed.range_in_excerpt;
+
+    let excerpt = prompt_inputs.cursor_excerpt.as_ref();
+    let old_text = excerpt[range_in_excerpt.clone()].to_string();
+    let mut new_text = parsed.new_editable_region;
 
-    let mut new_text = actual_output.to_string();
     let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
         new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
         Some(offset)
@@ -71,14 +74,8 @@ fn parse_zeta2_output(
         None
     };
 
-    if let Some(marker) = output_end_marker_for_format(format) {
-        new_text = new_text
-            .strip_suffix(marker)
-            .unwrap_or(&new_text)
-            .to_string();
-    }
-
-    let mut old_text_normalized = old_text.clone();
+    // Normalize trailing newlines for diff generation
+    let mut old_text_normalized = old_text;
     if !new_text.is_empty() && !new_text.ends_with('\n') {
         new_text.push('\n');
     }
@@ -86,22 +83,10 @@ fn parse_zeta2_output(
         old_text_normalized.push('\n');
     }
 
-    let old_text_trimmed = old_text.trim_end_matches('\n');
-    let excerpt = prompt_inputs.cursor_excerpt.as_ref();
-    let (editable_region_offset, _) = excerpt
-        .match_indices(old_text_trimmed)
-        .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset_in_excerpt))
-        .with_context(|| {
-            format!(
-                "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
-                old_text_trimmed, excerpt
-            )
-        })?;
-
+    let editable_region_offset = range_in_excerpt.start;
     let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
-
-    // Use full context so cursor offset (relative to editable region start) aligns with diff content
     let editable_region_lines = old_text_normalized.lines().count() as u32;
+
     let diff = language::unified_diff_with_context(
         &old_text_normalized,
         &new_text,

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -470,12 +470,19 @@ pub fn encode_patch_as_output_for_format(
     }
 }
 
+pub struct ParsedOutput {
+    /// Text that should replace the editable region
+    pub new_editable_region: String,
+    /// The byte range within `cursor_excerpt` that this replacement applies to
+    pub range_in_excerpt: Range<usize>,
+}
+
 /// Parse model output for the given zeta format
 pub fn parse_zeta2_model_output(
     output: &str,
     format: ZetaFormat,
     prompt_inputs: &ZetaPromptInput,
-) -> Result<(Range<usize>, String)> {
+) -> Result<ParsedOutput> {
     let output = match output_end_marker_for_format(format) {
         Some(marker) => output.strip_suffix(marker).unwrap_or(output),
         None => output,
@@ -509,7 +516,11 @@ pub fn parse_zeta2_model_output(
 
     let range_in_excerpt =
         range_in_context.start + context_start..range_in_context.end + context_start;
-    Ok((range_in_excerpt, output))
+
+    Ok(ParsedOutput {
+        new_editable_region: output,
+        range_in_excerpt,
+    })
 }
 
 pub fn excerpt_range_for_format(
@@ -4612,9 +4623,12 @@ mod tests {
         assert_eq!(cleaned, "");
     }
 
-    fn apply_edit(excerpt: &str, range: &Range<usize>, new_text: &str) -> String {
+    fn apply_edit(excerpt: &str, parsed_output: &ParsedOutput) -> String {
         let mut result = excerpt.to_string();
-        result.replace_range(range.clone(), new_text);
+        result.replace_range(
+            parsed_output.range_in_excerpt.clone(),
+            &parsed_output.new_editable_region,
+        );
         result
     }
 
@@ -4632,7 +4646,7 @@ mod tests {
             editable_start,
         );
 
-        let (range, text) = parse_zeta2_model_output(
+        let output = parse_zeta2_model_output(
             "editable new\n>>>>>>> UPDATED\n",
             ZetaFormat::V0131GitMergeMarkersPrefix,
             &input,
@@ -4640,7 +4654,7 @@ mod tests {
         .unwrap();
 
         assert_eq!(
-            apply_edit(excerpt, &range, &text),
+            apply_edit(excerpt, &output),
             "before ctx\nctx start\neditable new\nctx end\nafter ctx\n"
         );
     }
@@ -4658,10 +4672,10 @@ mod tests {
         );
 
         let format = ZetaFormat::V0131GitMergeMarkersPrefix;
-        let (range, text) =
+        let output =
             parse_zeta2_model_output("bbb\nccc\n>>>>>>> UPDATED\n", format, &input).unwrap();
 
-        assert_eq!(apply_edit(excerpt, &range, &text), excerpt);
+        assert_eq!(apply_edit(excerpt, &output), excerpt);
     }
 
     #[test]
@@ -4670,14 +4684,11 @@ mod tests {
         let input = make_input_with_context_range(excerpt, 0..excerpt.len(), 0..excerpt.len(), 0);
 
         let format = ZetaFormat::V0131GitMergeMarkersPrefix;
-        let (range1, text1) =
+        let output1 =
             parse_zeta2_model_output("new content\n>>>>>>> UPDATED\n", format, &input).unwrap();
-        let (range2, text2) = parse_zeta2_model_output("new content\n", format, &input).unwrap();
+        let output2 = parse_zeta2_model_output("new content\n", format, &input).unwrap();
 
-        assert_eq!(
-            apply_edit(excerpt, &range1, &text1),
-            apply_edit(excerpt, &range2, &text2)
-        );
-        assert_eq!(apply_edit(excerpt, &range1, &text1), "new content\n");
+        assert_eq!(apply_edit(excerpt, &output1), apply_edit(excerpt, &output2));
+        assert_eq!(apply_edit(excerpt, &output1), "new content\n");
     }
 }