Add `ep parse-output` command (#47220)

Oleksiy Syvokon created

This command takes raw LLM outputs (`predictions.actual_output`) that
could be generated elsewhere and parses them into a canonical unified
diff (`predictions.actual_patch`).

This is useful for simplifying the evaluation pipeline and for rerunning
the parser without having to generate LLM outputs.

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/distill.rs      |   2 
crates/edit_prediction_cli/src/example.rs      |   3 
crates/edit_prediction_cli/src/main.rs         |   8 
crates/edit_prediction_cli/src/parse_output.rs | 234 ++++++++++++++++++++
crates/edit_prediction_cli/src/predict.rs      |  20 
crates/edit_prediction_cli/src/score.rs        |  23 +
6 files changed, 276 insertions(+), 14 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/distill.rs 🔗

@@ -6,7 +6,7 @@ use crate::example::Example;
 pub async fn run_distill(example: &mut Example) -> Result<()> {
     let predictions = mem::take(&mut example.predictions)
         .into_iter()
-        .map(|p| p.actual_patch)
+        .filter_map(|p| p.actual_patch)
         .collect();
 
     example.spec.expected_patches = predictions;

crates/edit_prediction_cli/src/example.rs 🔗

@@ -73,7 +73,8 @@ pub struct ExamplePrompt {
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ExamplePrediction {
-    pub actual_patch: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub actual_patch: Option<String>,
     pub actual_output: String,
     pub provider: PredictionProvider,
 }

crates/edit_prediction_cli/src/main.rs 🔗

@@ -6,6 +6,7 @@ mod git;
 mod headless;
 mod load_project;
 mod metrics;
+mod parse_output;
 mod paths;
 mod predict;
 mod progress;
@@ -130,6 +131,9 @@ enum Command {
     FormatPrompt(FormatPromptArgs),
     /// Runs edit prediction
     Predict(PredictArgs),
+    /// Parse model outputs (actual_output) into unified diffs (actual_patch).
+    /// Requires format-prompt to have been run first. Uses provider from prompt.
+    ParseOutput,
     /// Computes a score based on actual and expected patches
     Score(PredictArgs),
     /// Prepares a distillation dataset by copying expected outputs to
@@ -159,6 +163,7 @@ impl Display for Command {
             Command::Predict(args) => {
                 write!(f, "predict --provider={}", args.provider)
             }
+            Command::ParseOutput => write!(f, "parse-output"),
             Command::Score(args) => {
                 write!(f, "score --provider={}", args.provider)
             }
@@ -601,6 +606,9 @@ fn main() {
                                             )
                                             .await?;
                                         }
+                                        Command::ParseOutput => {
+                                            parse_output::run_parse_output(example)?;
+                                        }
                                         Command::Distill => {
                                             run_distill(example).await?;
                                         }

crates/edit_prediction_cli/src/parse_output.rs 🔗

@@ -0,0 +1,234 @@
+use crate::{PredictionProvider, example::Example, format_prompt::TeacherPrompt};
+use anyhow::{Context as _, Result};
+use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
+
+pub fn run_parse_output(example: &mut Example) -> Result<()> {
+    let provider = example
+        .prompt
+        .as_ref()
+        .context("prompt required (run format-prompt first)")?
+        .provider;
+    example
+        .prompt_inputs
+        .as_ref()
+        .context("prompt_inputs required")?;
+
+    let parsed_patches: Vec<_> = example
+        .predictions
+        .iter()
+        .enumerate()
+        .filter(|(_, p)| !p.actual_output.is_empty())
+        .map(|(ix, prediction)| {
+            let actual_patch =
+                parse_prediction_output(example, &prediction.actual_output, provider);
+            actual_patch.map(|patch| (ix, patch))
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    for (ix, actual_patch) in parsed_patches {
+        example.predictions[ix].actual_patch = Some(actual_patch);
+        example.predictions[ix].provider = provider;
+    }
+
+    Ok(())
+}
+
+pub fn parse_prediction_output(
+    example: &Example,
+    actual_output: &str,
+    provider: PredictionProvider,
+) -> Result<String> {
+    match provider {
+        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
+            TeacherPrompt::parse(example, actual_output)
+        }
+        PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
+        _ => anyhow::bail!(
+            "parse-output only supports Teacher and Zeta2 providers, got {:?}",
+            provider
+        ),
+    }
+}
+
+fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<String> {
+    let (current_marker, end_marker) = match version {
+        ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
+        ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
+            ("<|fim_middle|>current\n", "<|fim_suffix|>")
+        }
+        ZetaVersion::V0120GitMergeMarkers => (
+            zeta_prompt::v0120_git_merge_markers::START_MARKER,
+            zeta_prompt::v0120_git_merge_markers::SEPARATOR,
+        ),
+    };
+
+    let start = prompt.find(current_marker).with_context(|| {
+        format!(
+            "missing current marker '{}' in prompt",
+            current_marker.trim()
+        )
+    })? + current_marker.len();
+
+    let end = prompt[start..]
+        .find(end_marker)
+        .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
+        + start;
+
+    let region = &prompt[start..end];
+    let region = region.strip_suffix('\n').unwrap_or(region);
+    Ok(region.replace(CURSOR_MARKER, ""))
+}
+
+fn parse_zeta2_output(
+    example: &Example,
+    actual_output: &str,
+    version: ZetaVersion,
+) -> Result<String> {
+    let prompt = &example.prompt.as_ref().context("prompt required")?.input;
+    let prompt_inputs = example
+        .prompt_inputs
+        .as_ref()
+        .context("prompt_inputs required")?;
+
+    let old_text = extract_zeta2_current_region(prompt, version)?;
+
+    let mut new_text = actual_output.replace(CURSOR_MARKER, "");
+
+    if version == ZetaVersion::V0120GitMergeMarkers {
+        if let Some(stripped) =
+            new_text.strip_suffix(zeta_prompt::v0120_git_merge_markers::END_MARKER)
+        {
+            new_text = stripped.to_string();
+        }
+    }
+
+    let mut old_text_normalized = old_text.clone();
+    if !new_text.is_empty() && !new_text.ends_with('\n') {
+        new_text.push('\n');
+    }
+    if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
+        old_text_normalized.push('\n');
+    }
+
+    let old_text_trimmed = old_text.trim_end_matches('\n');
+    let (editable_region_offset, _) = prompt_inputs
+        .content
+        .match_indices(old_text_trimmed)
+        .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
+        .with_context(|| {
+            format!(
+                "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
+                old_text_trimmed, &prompt_inputs.content
+            )
+        })?;
+
+    let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
+        .matches('\n')
+        .count();
+
+    let diff = language::unified_diff_with_offsets(
+        &old_text_normalized,
+        &new_text,
+        editable_region_start_line as u32,
+        editable_region_start_line as u32,
+    );
+
+    let formatted_diff = format!(
+        "--- a/{path}\n+++ b/{path}\n{diff}",
+        path = example.spec.cursor_path.to_string_lossy(),
+    );
+
+    Ok(formatted_diff)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_extract_zeta2_current_region_v0113() {
+        let prompt = indoc::indoc! {"
+            <|file_sep|>src/main.rs
+            <|fim_prefix|>
+            fn main() {
+            <|fim_middle|>current
+            println!(\"hello\");
+            <|fim_suffix|>
+            }
+            <|fim_middle|>updated
+        "};
+
+        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
+        assert_eq!(region, "println!(\"hello\");");
+    }
+
+    #[test]
+    fn test_extract_zeta2_current_region_v0112() {
+        let prompt = indoc::indoc! {"
+            <|file_sep|>src/main.rs
+            <|fim_prefix|>
+            fn main() {
+            <|fim_suffix|>
+            }
+            <|fim_middle|>current
+            println!(\"hello\");
+            <|fim_middle|>updated
+        "};
+
+        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0112MiddleAtEnd).unwrap();
+        assert_eq!(region, "println!(\"hello\");");
+    }
+
+    #[test]
+    fn test_extract_zeta2_current_region_with_cursor_marker() {
+        let prompt = indoc::indoc! {"
+            <|file_sep|>src/main.rs
+            <|fim_prefix|>
+            fn main() {
+            <|fim_middle|>current
+            print<|user_cursor|>ln!(\"hello\");
+            <|fim_suffix|>
+            }
+            <|fim_middle|>updated
+        "};
+
+        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
+        assert_eq!(region, "println!(\"hello\");");
+    }
+
+    #[test]
+    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
+        let prompt = indoc::indoc! {"
+            <|file_sep|>src/main.rs
+            <|fim_prefix|>
+            fn main() {
+            <|fim_suffix|>
+            }
+            <|fim_middle|><<<<<<< CURRENT
+            println!(\"hello\");
+            =======
+        "};
+
+        let region =
+            extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
+        assert_eq!(region, "println!(\"hello\");");
+    }
+
+    #[test]
+    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
+        let prompt = indoc::indoc! {"
+            <|file_sep|>src/main.rs
+            <|fim_prefix|>
+            fn main() {
+            <|fim_suffix|>
+            }
+            <|fim_middle|><<<<<<< CURRENT
+            print<|user_cursor|>ln!(\"hello\");
+            =======
+        "};
+
+        let region =
+            extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
+        assert_eq!(region, "println!(\"hello\");");
+    }
+}

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -186,7 +186,7 @@ pub async fn run_prediction(
             .unwrap()
             .predictions
             .push(ExamplePrediction {
-                actual_patch: String::new(),
+                actual_patch: None,
                 actual_output: String::new(),
                 provider,
             });
@@ -204,16 +204,14 @@ pub async fn run_prediction(
             })
             .await?;
 
-        let actual_patch = prediction
-            .and_then(|prediction| {
-                let prediction = prediction.prediction.ok()?;
-                prediction
-                    .edit_preview
-                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
-            })
-            .unwrap_or_default();
+        let actual_patch = prediction.and_then(|prediction| {
+            let prediction = prediction.prediction.ok()?;
+            prediction
+                .edit_preview
+                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
+        });
 
-        let has_prediction = !actual_patch.is_empty();
+        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());
 
         updated_example
             .lock()
@@ -293,7 +291,7 @@ async fn predict_anthropic(
     let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;
 
     let prediction = ExamplePrediction {
-        actual_patch,
+        actual_patch: Some(actual_patch),
         actual_output,
         provider: if batched {
             PredictionProvider::Teacher(version)

crates/edit_prediction_cli/src/score.rs 🔗

@@ -3,6 +3,7 @@ use crate::{
     example::{Example, ExampleScore},
     headless::EpAppState,
     metrics,
+    parse_output::parse_prediction_output,
     predict::run_prediction,
     progress::{ExampleProgress, Step},
 };
@@ -37,7 +38,27 @@ pub async fn run_scoring(
     progress.set_substatus("computing metrics");
     let mut scores = vec![];
     for prediction in &example.predictions {
-        let actual_text = match apply_diff_to_string(&prediction.actual_patch, original_text) {
+        let actual_patch = match &prediction.actual_patch {
+            Some(patch) => patch.clone(),
+            None => {
+                if prediction.actual_output.is_empty() {
+                    scores.push(ExampleScore { delta_chr_f: 0.0 });
+                    continue;
+                }
+                match parse_prediction_output(
+                    example,
+                    &prediction.actual_output,
+                    prediction.provider,
+                ) {
+                    Ok(patch) => patch,
+                    Err(_) => {
+                        scores.push(ExampleScore { delta_chr_f: 0.0 });
+                        continue;
+                    }
+                }
+            }
+        };
+        let actual_text = match apply_diff_to_string(&actual_patch, original_text) {
             Ok(text) => text,
             Err(_) => {
                 scores.push(ExampleScore { delta_chr_f: 0.0 });