parse_output.rs

  1use crate::{
  2    PredictionProvider,
  3    example::{ActualCursor, Example},
  4    format_prompt::TeacherPrompt,
  5    repair,
  6};
  7use anyhow::{Context as _, Result};
  8use edit_prediction::example_spec::encode_cursor_in_patch;
  9use zeta_prompt::{CURSOR_MARKER, ZetaFormat, output_end_marker_for_format, resolve_cursor_region};
 10
 11pub fn run_parse_output(example: &mut Example) -> Result<()> {
 12    example
 13        .prompt_inputs
 14        .as_ref()
 15        .context("prompt_inputs required")?;
 16
 17    let to_parse: Vec<_> = example
 18        .predictions
 19        .iter()
 20        .enumerate()
 21        .filter(|(_, p)| !p.actual_output.is_empty())
 22        .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
 23        .collect();
 24
 25    for (ix, actual_output, provider) in to_parse {
 26        let (actual_patch, actual_cursor) =
 27            parse_prediction_output(example, &actual_output, provider)?;
 28        example.predictions[ix].actual_patch = Some(actual_patch);
 29        example.predictions[ix].actual_cursor = actual_cursor;
 30    }
 31
 32    Ok(())
 33}
 34
 35pub fn parse_prediction_output(
 36    example: &Example,
 37    actual_output: &str,
 38    provider: PredictionProvider,
 39) -> Result<(String, Option<ActualCursor>)> {
 40    match provider {
 41        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 42            TeacherPrompt::parse(example, actual_output)
 43        }
 44        PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
 45        PredictionProvider::Repair => repair::parse(example, actual_output),
 46        _ => anyhow::bail!(
 47            "parse-output only supports Teacher and Zeta2 providers, got {:?}",
 48            provider
 49        ),
 50    }
 51}
 52
 53fn parse_zeta2_output(
 54    example: &Example,
 55    actual_output: &str,
 56    format: ZetaFormat,
 57) -> Result<(String, Option<ActualCursor>)> {
 58    let prompt_inputs = example
 59        .prompt_inputs
 60        .as_ref()
 61        .context("prompt_inputs required")?;
 62
 63    let (context, editable_range, _, _) = resolve_cursor_region(prompt_inputs, format);
 64    let old_text = context[editable_range].to_string();
 65
 66    let mut new_text = actual_output.to_string();
 67    let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
 68        new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
 69        Some(offset)
 70    } else {
 71        None
 72    };
 73
 74    if let Some(marker) = output_end_marker_for_format(format) {
 75        new_text = new_text
 76            .strip_suffix(marker)
 77            .unwrap_or(&new_text)
 78            .to_string();
 79    }
 80
 81    let mut old_text_normalized = old_text.clone();
 82    if !new_text.is_empty() && !new_text.ends_with('\n') {
 83        new_text.push('\n');
 84    }
 85    if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
 86        old_text_normalized.push('\n');
 87    }
 88
 89    let old_text_trimmed = old_text.trim_end_matches('\n');
 90    let excerpt = prompt_inputs.cursor_excerpt.as_ref();
 91    let (editable_region_offset, _) = excerpt
 92        .match_indices(old_text_trimmed)
 93        .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset_in_excerpt))
 94        .with_context(|| {
 95            format!(
 96                "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
 97                old_text_trimmed, excerpt
 98            )
 99        })?;
100
101    let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
102
103    // Use full context so cursor offset (relative to editable region start) aligns with diff content
104    let editable_region_lines = old_text_normalized.lines().count() as u32;
105    let diff = language::unified_diff_with_context(
106        &old_text_normalized,
107        &new_text,
108        editable_region_start_line as u32,
109        editable_region_start_line as u32,
110        editable_region_lines,
111    );
112
113    let formatted_diff = format!(
114        "--- a/{path}\n+++ b/{path}\n{diff}",
115        path = example.spec.cursor_path.to_string_lossy(),
116    );
117
118    let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
119
120    let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
121        ActualCursor::from_editable_region(
122            &example.spec.cursor_path,
123            editable_region_cursor_offset,
124            &new_text,
125            excerpt,
126            editable_region_offset,
127            editable_region_start_line,
128        )
129    });
130
131    Ok((formatted_diff, actual_cursor))
132}