parse_output.rs

  1use crate::{
  2    PredictionProvider,
  3    example::{ActualCursor, Example},
  4    format_prompt::TeacherPrompt,
  5    repair,
  6};
  7use anyhow::{Context as _, Result};
  8use edit_prediction::example_spec::encode_cursor_in_patch;
  9use zeta_prompt::{CURSOR_MARKER, ZetaFormat};
 10
 11pub fn run_parse_output(example: &mut Example) -> Result<()> {
 12    example
 13        .prompt_inputs
 14        .as_ref()
 15        .context("prompt_inputs required")?;
 16
 17    let to_parse: Vec<_> = example
 18        .predictions
 19        .iter()
 20        .enumerate()
 21        .filter(|(_, p)| !p.actual_output.is_empty())
 22        .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
 23        .collect();
 24
 25    for (ix, actual_output, provider) in to_parse {
 26        let (actual_patch, actual_cursor) =
 27            parse_prediction_output(example, &actual_output, provider)?;
 28        example.predictions[ix].actual_patch = Some(actual_patch);
 29        example.predictions[ix].actual_cursor = actual_cursor;
 30    }
 31
 32    Ok(())
 33}
 34
 35pub fn parse_prediction_output(
 36    example: &Example,
 37    actual_output: &str,
 38    provider: PredictionProvider,
 39) -> Result<(String, Option<ActualCursor>)> {
 40    match provider {
 41        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 42            TeacherPrompt::parse(example, actual_output)
 43        }
 44        PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
 45        PredictionProvider::Repair => repair::parse(example, actual_output),
 46        _ => anyhow::bail!(
 47            "parse-output only supports Teacher and Zeta2 providers, got {:?}",
 48            provider
 49        ),
 50    }
 51}
 52
 53fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<String> {
 54    let (current_marker, end_marker) = match format {
 55        ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
 56        ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
 57            ("<|fim_middle|>current\n", "<|fim_suffix|>")
 58        }
 59        ZetaFormat::V0120GitMergeMarkers
 60        | ZetaFormat::V0131GitMergeMarkersPrefix
 61        | ZetaFormat::V0211Prefill => (
 62            zeta_prompt::v0120_git_merge_markers::START_MARKER,
 63            zeta_prompt::v0120_git_merge_markers::SEPARATOR,
 64        ),
 65        ZetaFormat::V0211SeedCoder => (
 66            zeta_prompt::seed_coder::START_MARKER,
 67            zeta_prompt::seed_coder::SEPARATOR,
 68        ),
 69    };
 70
 71    let start = prompt.find(current_marker).with_context(|| {
 72        format!(
 73            "missing current marker '{}' in prompt",
 74            current_marker.trim()
 75        )
 76    })? + current_marker.len();
 77
 78    let end = prompt[start..]
 79        .find(end_marker)
 80        .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
 81        + start;
 82
 83    let region = &prompt[start..end];
 84    let region = region.replace(CURSOR_MARKER, "");
 85
 86    Ok(region)
 87}
 88
 89fn parse_zeta2_output(
 90    example: &Example,
 91    actual_output: &str,
 92    format: ZetaFormat,
 93) -> Result<(String, Option<ActualCursor>)> {
 94    let prompt = &example.prompt.as_ref().context("prompt required")?.input;
 95    let prompt_inputs = example
 96        .prompt_inputs
 97        .as_ref()
 98        .context("prompt_inputs required")?;
 99
100    let old_text = extract_zeta2_current_region(prompt, format)?;
101
102    let mut new_text = actual_output.to_string();
103    let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
104        new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
105        Some(offset)
106    } else {
107        None
108    };
109
110    let suffix = match format {
111        ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
112            zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
113        }
114        ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
115        ZetaFormat::V0112MiddleAtEnd
116        | ZetaFormat::V0113Ordered
117        | ZetaFormat::V0114180EditableRegion => "",
118        ZetaFormat::V0211SeedCoder => zeta_prompt::seed_coder::END_MARKER,
119    };
120    if !suffix.is_empty() {
121        new_text = new_text
122            .strip_suffix(suffix)
123            .unwrap_or(&new_text)
124            .to_string();
125    }
126
127    let mut old_text_normalized = old_text.clone();
128    if !new_text.is_empty() && !new_text.ends_with('\n') {
129        new_text.push('\n');
130    }
131    if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
132        old_text_normalized.push('\n');
133    }
134
135    let old_text_trimmed = old_text.trim_end_matches('\n');
136    let (editable_region_offset, _) = prompt_inputs
137        .content
138        .match_indices(old_text_trimmed)
139        .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
140        .with_context(|| {
141            format!(
142                "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
143                old_text_trimmed, &prompt_inputs.content
144            )
145        })?;
146
147    let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
148        .matches('\n')
149        .count();
150
151    // Use full context so cursor offset (relative to editable region start) aligns with diff content
152    let editable_region_lines = old_text_normalized.lines().count() as u32;
153    let diff = language::unified_diff_with_context(
154        &old_text_normalized,
155        &new_text,
156        editable_region_start_line as u32,
157        editable_region_start_line as u32,
158        editable_region_lines,
159    );
160
161    let formatted_diff = format!(
162        "--- a/{path}\n+++ b/{path}\n{diff}",
163        path = example.spec.cursor_path.to_string_lossy(),
164    );
165
166    let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
167
168    let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
169        ActualCursor::from_editable_region(
170            &example.spec.cursor_path,
171            editable_region_cursor_offset,
172            &new_text,
173            &prompt_inputs.content,
174            editable_region_offset,
175            editable_region_start_line,
176        )
177    });
178
179    Ok((formatted_diff, actual_cursor))
180}
181
182#[cfg(test)]
183mod tests {
184    use super::*;
185
186    #[test]
187    fn test_extract_zeta2_current_region_v0113() {
188        let prompt = indoc::indoc! {"
189            <|file_sep|>src/main.rs
190            <|fim_prefix|>
191            fn main() {
192            <|fim_middle|>current
193            println!(\"hello\");
194            <|fim_suffix|>
195            }
196            <|fim_middle|>updated
197        "};
198
199        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
200        assert_eq!(region, "println!(\"hello\");\n");
201    }
202
203    #[test]
204    fn test_extract_zeta2_current_region_v0112() {
205        let prompt = indoc::indoc! {"
206            <|file_sep|>src/main.rs
207            <|fim_prefix|>
208            fn main() {
209            <|fim_suffix|>
210            }
211            <|fim_middle|>current
212            println!(\"hello\");
213            <|fim_middle|>updated
214        "};
215
216        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0112MiddleAtEnd).unwrap();
217        assert_eq!(region, "println!(\"hello\");\n");
218    }
219
220    #[test]
221    fn test_extract_zeta2_current_region_with_cursor_marker() {
222        let prompt = indoc::indoc! {"
223            <|file_sep|>src/main.rs
224            <|fim_prefix|>
225            fn main() {
226            <|fim_middle|>current
227            print<|user_cursor|>ln!(\"hello\");
228            <|fim_suffix|>
229            }
230            <|fim_middle|>updated
231        "};
232
233        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
234        assert_eq!(region, "println!(\"hello\");\n");
235    }
236
237    #[test]
238    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
239        let prompt = indoc::indoc! {"
240            <|file_sep|>src/main.rs
241            <|fim_prefix|>
242            fn main() {
243            <|fim_suffix|>
244            }
245            <|fim_middle|><<<<<<< CURRENT
246            println!(\"hello\");
247            =======
248        "};
249
250        let region =
251            extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
252        assert_eq!(region, "println!(\"hello\");\n");
253    }
254
255    #[test]
256    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
257        let prompt = indoc::indoc! {"
258            <|file_sep|>src/main.rs
259            <|fim_prefix|>
260            fn main() {
261            <|fim_suffix|>
262            }
263            <|fim_middle|><<<<<<< CURRENT
264            print<|user_cursor|>ln!(\"hello\");
265            =======
266        "};
267
268        let region =
269            extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
270        assert_eq!(region, "println!(\"hello\");\n");
271    }
272}