parse_output.rs

  1use crate::{
  2    PredictionProvider,
  3    example::{ActualCursor, Example},
  4    format_prompt::TeacherPrompt,
  5    repair,
  6};
  7use anyhow::{Context as _, Result};
  8use zeta_prompt::{CURSOR_MARKER, ZetaFormat};
  9
 10pub fn run_parse_output(example: &mut Example) -> Result<()> {
 11    example
 12        .prompt_inputs
 13        .as_ref()
 14        .context("prompt_inputs required")?;
 15
 16    let to_parse: Vec<_> = example
 17        .predictions
 18        .iter()
 19        .enumerate()
 20        .filter(|(_, p)| !p.actual_output.is_empty())
 21        .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
 22        .collect();
 23
 24    for (ix, actual_output, provider) in to_parse {
 25        let (actual_patch, actual_cursor) =
 26            parse_prediction_output(example, &actual_output, provider)?;
 27        example.predictions[ix].actual_patch = Some(actual_patch);
 28        example.predictions[ix].actual_cursor = actual_cursor;
 29    }
 30
 31    Ok(())
 32}
 33
 34pub fn parse_prediction_output(
 35    example: &Example,
 36    actual_output: &str,
 37    provider: PredictionProvider,
 38) -> Result<(String, Option<ActualCursor>)> {
 39    match provider {
 40        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 41            TeacherPrompt::parse(example, actual_output)
 42        }
 43        PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
 44        PredictionProvider::Repair => repair::parse(example, actual_output),
 45        _ => anyhow::bail!(
 46            "parse-output only supports Teacher and Zeta2 providers, got {:?}",
 47            provider
 48        ),
 49    }
 50}
 51
 52fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<String> {
 53    let (current_marker, end_marker) = match format {
 54        ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
 55        ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
 56            ("<|fim_middle|>current\n", "<|fim_suffix|>")
 57        }
 58        ZetaFormat::V0120GitMergeMarkers
 59        | ZetaFormat::V0131GitMergeMarkersPrefix
 60        | ZetaFormat::V0211Prefill => (
 61            zeta_prompt::v0120_git_merge_markers::START_MARKER,
 62            zeta_prompt::v0120_git_merge_markers::SEPARATOR,
 63        ),
 64        ZetaFormat::V0211SeedCoder => (
 65            zeta_prompt::seed_coder::START_MARKER,
 66            zeta_prompt::seed_coder::SEPARATOR,
 67        ),
 68    };
 69
 70    let start = prompt.find(current_marker).with_context(|| {
 71        format!(
 72            "missing current marker '{}' in prompt",
 73            current_marker.trim()
 74        )
 75    })? + current_marker.len();
 76
 77    let end = prompt[start..]
 78        .find(end_marker)
 79        .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
 80        + start;
 81
 82    let region = &prompt[start..end];
 83    let region = region.replace(CURSOR_MARKER, "");
 84
 85    Ok(region)
 86}
 87
 88fn parse_zeta2_output(
 89    example: &Example,
 90    actual_output: &str,
 91    format: ZetaFormat,
 92) -> Result<(String, Option<ActualCursor>)> {
 93    let prompt = &example.prompt.as_ref().context("prompt required")?.input;
 94    let prompt_inputs = example
 95        .prompt_inputs
 96        .as_ref()
 97        .context("prompt_inputs required")?;
 98
 99    let old_text = extract_zeta2_current_region(prompt, format)?;
100
101    let mut new_text = actual_output.to_string();
102    let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
103        new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
104        Some(offset)
105    } else {
106        None
107    };
108
109    let suffix = match format {
110        ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
111            zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
112        }
113        ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
114        ZetaFormat::V0112MiddleAtEnd
115        | ZetaFormat::V0113Ordered
116        | ZetaFormat::V0114180EditableRegion => "",
117        ZetaFormat::V0211SeedCoder => zeta_prompt::seed_coder::END_MARKER,
118    };
119    if !suffix.is_empty() {
120        new_text = new_text
121            .strip_suffix(suffix)
122            .unwrap_or(&new_text)
123            .to_string();
124    }
125
126    let mut old_text_normalized = old_text.clone();
127    if !new_text.is_empty() && !new_text.ends_with('\n') {
128        new_text.push('\n');
129    }
130    if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
131        old_text_normalized.push('\n');
132    }
133
134    let old_text_trimmed = old_text.trim_end_matches('\n');
135    let (editable_region_offset, _) = prompt_inputs
136        .content
137        .match_indices(old_text_trimmed)
138        .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
139        .with_context(|| {
140            format!(
141                "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
142                old_text_trimmed, &prompt_inputs.content
143            )
144        })?;
145
146    let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
147        .matches('\n')
148        .count();
149
150    // Use full context so cursor offset (relative to editable region start) aligns with diff content
151    let editable_region_lines = old_text_normalized.lines().count() as u32;
152    let diff = language::unified_diff_with_context(
153        &old_text_normalized,
154        &new_text,
155        editable_region_start_line as u32,
156        editable_region_start_line as u32,
157        editable_region_lines,
158    );
159
160    let formatted_diff = format!(
161        "--- a/{path}\n+++ b/{path}\n{diff}",
162        path = example.spec.cursor_path.to_string_lossy(),
163    );
164
165    let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
166        ActualCursor::from_editable_region(
167            &example.spec.cursor_path,
168            editable_region_cursor_offset,
169            &new_text,
170            &prompt_inputs.content,
171            editable_region_offset,
172            editable_region_start_line,
173        )
174    });
175
176    Ok((formatted_diff, actual_cursor))
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Region text that every fixture below is expected to yield.
    const EXPECTED_REGION: &str = "println!(\"hello\");\n";

    /// Extracts the current region from `prompt` for `format` and checks that it equals
    /// [`EXPECTED_REGION`].
    #[track_caller]
    fn assert_extracts_expected_region(prompt: &str, format: ZetaFormat) {
        let region = extract_zeta2_current_region(prompt, format).unwrap();
        assert_eq!(region, EXPECTED_REGION);
    }

    // Region sits between `<|fim_middle|>current` and `<|fim_suffix|>`.
    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        assert_extracts_expected_region(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_middle|>current
                println!(\"hello\");
                <|fim_suffix|>
                }
                <|fim_middle|>updated
            "},
            ZetaFormat::V0113Ordered,
        );
    }

    // Region sits between `<|fim_middle|>current` and `<|fim_middle|>updated`.
    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        assert_extracts_expected_region(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_suffix|>
                }
                <|fim_middle|>current
                println!(\"hello\");
                <|fim_middle|>updated
            "},
            ZetaFormat::V0112MiddleAtEnd,
        );
    }

    // A cursor marker inside the region is removed from the extracted text.
    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        assert_extracts_expected_region(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_middle|>current
                print<|user_cursor|>ln!(\"hello\");
                <|fim_suffix|>
                }
                <|fim_middle|>updated
            "},
            ZetaFormat::V0113Ordered,
        );
    }

    // Git-merge-marker format: region sits between the CURRENT marker and `=======`.
    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        assert_extracts_expected_region(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_suffix|>
                }
                <|fim_middle|><<<<<<< CURRENT
                println!(\"hello\");
                =======
            "},
            ZetaFormat::V0120GitMergeMarkers,
        );
    }

    // Git-merge-marker format with a cursor marker inside the region.
    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        assert_extracts_expected_region(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_suffix|>
                }
                <|fim_middle|><<<<<<< CURRENT
                print<|user_cursor|>ln!(\"hello\");
                =======
            "},
            ZetaFormat::V0120GitMergeMarkers,
        );
    }
}