parse_output.rs

  1use crate::{
  2    PredictionProvider,
  3    example::{ActualCursor, Example},
  4    format_prompt::TeacherPrompt,
  5    repair,
  6};
  7use anyhow::{Context as _, Result};
  8use edit_prediction::example_spec::encode_cursor_in_patch;
  9use zeta_prompt::{
 10    CURSOR_MARKER, ZetaFormat, clean_extracted_region_for_format,
 11    current_region_markers_for_format, output_end_marker_for_format,
 12    output_with_context_for_format,
 13};
 14
 15pub fn run_parse_output(example: &mut Example) -> Result<()> {
 16    example
 17        .prompt_inputs
 18        .as_ref()
 19        .context("prompt_inputs required")?;
 20
 21    let to_parse: Vec<_> = example
 22        .predictions
 23        .iter()
 24        .enumerate()
 25        .filter(|(_, p)| !p.actual_output.is_empty())
 26        .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
 27        .collect();
 28
 29    for (ix, actual_output, provider) in to_parse {
 30        let (actual_patch, actual_cursor) =
 31            parse_prediction_output(example, &actual_output, provider)?;
 32        example.predictions[ix].actual_patch = Some(actual_patch);
 33        example.predictions[ix].actual_cursor = actual_cursor;
 34    }
 35
 36    Ok(())
 37}
 38
 39pub fn parse_prediction_output(
 40    example: &Example,
 41    actual_output: &str,
 42    provider: PredictionProvider,
 43) -> Result<(String, Option<ActualCursor>)> {
 44    match provider {
 45        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 46            TeacherPrompt::parse(example, actual_output)
 47        }
 48        PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
 49        PredictionProvider::Repair => repair::parse(example, actual_output),
 50        _ => anyhow::bail!(
 51            "parse-output only supports Teacher and Zeta2 providers, got {:?}",
 52            provider
 53        ),
 54    }
 55}
 56
 57fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<String> {
 58    let (current_marker, end_marker) = current_region_markers_for_format(format);
 59
 60    let start = prompt.find(current_marker).with_context(|| {
 61        format!(
 62            "missing current marker '{}' in prompt",
 63            current_marker.trim()
 64        )
 65    })? + current_marker.len();
 66
 67    let end = prompt[start..]
 68        .find(end_marker)
 69        .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
 70        + start;
 71
 72    let region = &prompt[start..end];
 73    let region = region.replace(CURSOR_MARKER, "");
 74    Ok(clean_extracted_region_for_format(format, &region))
 75}
 76
 77fn parse_zeta2_output(
 78    example: &Example,
 79    actual_output: &str,
 80    format: ZetaFormat,
 81) -> Result<(String, Option<ActualCursor>)> {
 82    let prompt = &example.prompt.as_ref().context("prompt required")?.input;
 83    let prompt_inputs = example
 84        .prompt_inputs
 85        .as_ref()
 86        .context("prompt_inputs required")?;
 87
 88    let old_text = extract_zeta2_current_region(prompt, format)?;
 89
 90    let mut new_text = actual_output.to_string();
 91    if let Some(transformed) = output_with_context_for_format(format, &old_text, &new_text)? {
 92        new_text = transformed;
 93    }
 94    let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
 95        new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
 96        Some(offset)
 97    } else {
 98        None
 99    };
100
101    if let Some(marker) = output_end_marker_for_format(format) {
102        new_text = new_text
103            .strip_suffix(marker)
104            .unwrap_or(&new_text)
105            .to_string();
106    }
107
108    let mut old_text_normalized = old_text.clone();
109    if !new_text.is_empty() && !new_text.ends_with('\n') {
110        new_text.push('\n');
111    }
112    if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
113        old_text_normalized.push('\n');
114    }
115
116    let old_text_trimmed = old_text.trim_end_matches('\n');
117    let excerpt = prompt_inputs.cursor_excerpt.as_ref();
118    let (editable_region_offset, _) = excerpt
119        .match_indices(old_text_trimmed)
120        .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset_in_excerpt))
121        .with_context(|| {
122            format!(
123                "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
124                old_text_trimmed, excerpt
125            )
126        })?;
127
128    let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
129
130    // Use full context so cursor offset (relative to editable region start) aligns with diff content
131    let editable_region_lines = old_text_normalized.lines().count() as u32;
132    let diff = language::unified_diff_with_context(
133        &old_text_normalized,
134        &new_text,
135        editable_region_start_line as u32,
136        editable_region_start_line as u32,
137        editable_region_lines,
138    );
139
140    let formatted_diff = format!(
141        "--- a/{path}\n+++ b/{path}\n{diff}",
142        path = example.spec.cursor_path.to_string_lossy(),
143    );
144
145    let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
146
147    let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
148        ActualCursor::from_editable_region(
149            &example.spec.cursor_path,
150            editable_region_cursor_offset,
151            &new_text,
152            excerpt,
153            editable_region_offset,
154            editable_region_start_line,
155        )
156    });
157
158    Ok((formatted_diff, actual_cursor))
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0112MiddleAtEnd).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            print<|user_cursor|>ln!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    /// A prompt with no region markers at all must be rejected, not parsed as empty.
    #[test]
    fn test_extract_zeta2_current_region_missing_current_marker() {
        let prompt = "fn main() {\n    println!(\"hello\");\n}\n";

        for format in [
            ZetaFormat::V0113Ordered,
            ZetaFormat::V0112MiddleAtEnd,
            ZetaFormat::V0120GitMergeMarkers,
        ] {
            let err = extract_zeta2_current_region(prompt, format).unwrap_err();
            assert!(
                err.to_string().starts_with("missing current marker"),
                "unexpected error for {format:?}: {err}"
            );
        }
    }

    /// A region that is opened but never closed (e.g. a truncated prompt) must be rejected.
    #[test]
    fn test_extract_zeta2_current_region_v0120_missing_end_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
        "};

        let err =
            extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap_err();
        assert!(
            err.to_string().starts_with("missing end marker"),
            "unexpected error: {err}"
        );
    }
}