parse_output.rs

  1use crate::{
  2    PredictionProvider,
  3    example::{ActualCursor, Example},
  4    format_prompt::TeacherPrompt,
  5};
  6use anyhow::{Context as _, Result};
  7use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
  8
  9pub fn run_parse_output(example: &mut Example) -> Result<()> {
 10    let provider = example
 11        .prompt
 12        .as_ref()
 13        .context("prompt required (run format-prompt first)")?
 14        .provider;
 15    example
 16        .prompt_inputs
 17        .as_ref()
 18        .context("prompt_inputs required")?;
 19
 20    let parsed_patches: Vec<_> = example
 21        .predictions
 22        .iter()
 23        .enumerate()
 24        .filter(|(_, p)| !p.actual_output.is_empty())
 25        .map(|(ix, prediction)| {
 26            let result = parse_prediction_output(example, &prediction.actual_output, provider);
 27            result.map(|(patch, cursor_offset)| (ix, patch, cursor_offset))
 28        })
 29        .collect::<Result<Vec<_>>>()?;
 30
 31    for (ix, actual_patch, actual_cursor) in parsed_patches {
 32        example.predictions[ix].actual_patch = Some(actual_patch);
 33        example.predictions[ix].actual_cursor = actual_cursor;
 34        example.predictions[ix].provider = provider;
 35    }
 36
 37    Ok(())
 38}
 39
 40pub fn parse_prediction_output(
 41    example: &Example,
 42    actual_output: &str,
 43    provider: PredictionProvider,
 44) -> Result<(String, Option<ActualCursor>)> {
 45    match provider {
 46        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 47            TeacherPrompt::parse(example, actual_output)
 48        }
 49        PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
 50        _ => anyhow::bail!(
 51            "parse-output only supports Teacher and Zeta2 providers, got {:?}",
 52            provider
 53        ),
 54    }
 55}
 56
 57fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<String> {
 58    let (current_marker, end_marker) = match version {
 59        ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
 60        ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
 61            ("<|fim_middle|>current\n", "<|fim_suffix|>")
 62        }
 63        ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => (
 64            zeta_prompt::v0120_git_merge_markers::START_MARKER,
 65            zeta_prompt::v0120_git_merge_markers::SEPARATOR,
 66        ),
 67    };
 68
 69    let start = prompt.find(current_marker).with_context(|| {
 70        format!(
 71            "missing current marker '{}' in prompt",
 72            current_marker.trim()
 73        )
 74    })? + current_marker.len();
 75
 76    let end = prompt[start..]
 77        .find(end_marker)
 78        .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
 79        + start;
 80
 81    let region = &prompt[start..end];
 82    let region = region.replace(CURSOR_MARKER, "");
 83
 84    Ok(region)
 85}
 86
 87fn parse_zeta2_output(
 88    example: &Example,
 89    actual_output: &str,
 90    version: ZetaVersion,
 91) -> Result<(String, Option<ActualCursor>)> {
 92    let prompt = &example.prompt.as_ref().context("prompt required")?.input;
 93    let prompt_inputs = example
 94        .prompt_inputs
 95        .as_ref()
 96        .context("prompt_inputs required")?;
 97
 98    let old_text = extract_zeta2_current_region(prompt, version)?;
 99
100    let mut new_text = actual_output.to_string();
101    let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
102        new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
103        Some(offset)
104    } else {
105        None
106    };
107
108    let suffix = match version {
109        ZetaVersion::V0131GitMergeMarkersPrefix => {
110            zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
111        }
112        ZetaVersion::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
113        _ => "",
114    };
115    if !suffix.is_empty() {
116        new_text = new_text
117            .strip_suffix(suffix)
118            .unwrap_or(&new_text)
119            .to_string();
120    }
121
122    let mut old_text_normalized = old_text.clone();
123    if !new_text.is_empty() && !new_text.ends_with('\n') {
124        new_text.push('\n');
125    }
126    if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
127        old_text_normalized.push('\n');
128    }
129
130    let old_text_trimmed = old_text.trim_end_matches('\n');
131    let (editable_region_offset, _) = prompt_inputs
132        .content
133        .match_indices(old_text_trimmed)
134        .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
135        .with_context(|| {
136            format!(
137                "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
138                old_text_trimmed, &prompt_inputs.content
139            )
140        })?;
141
142    let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
143        .matches('\n')
144        .count();
145
146    // Use full context so cursor offset (relative to editable region start) aligns with diff content
147    let editable_region_lines = old_text_normalized.lines().count() as u32;
148    let diff = language::unified_diff_with_context(
149        &old_text_normalized,
150        &new_text,
151        editable_region_start_line as u32,
152        editable_region_start_line as u32,
153        editable_region_lines,
154    );
155
156    let formatted_diff = format!(
157        "--- a/{path}\n+++ b/{path}\n{diff}",
158        path = example.spec.cursor_path.to_string_lossy(),
159    );
160
161    let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
162        ActualCursor::from_editable_region(
163            &example.spec.cursor_path,
164            editable_region_cursor_offset,
165            &new_text,
166            &prompt_inputs.content,
167            editable_region_offset,
168            editable_region_start_line,
169        )
170    });
171
172    Ok((formatted_diff, actual_cursor))
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// The region every fixture below should produce once markers are stripped.
    const EXPECTED_REGION: &str = "println!(\"hello\");\n";

    /// Extracts the current region from `prompt` and checks it against
    /// [`EXPECTED_REGION`].
    #[track_caller]
    fn assert_extracts_expected_region(prompt: &str, version: ZetaVersion) {
        let region = extract_zeta2_current_region(prompt, version).unwrap();
        assert_eq!(region, EXPECTED_REGION);
    }

    // Ordered layout: the current region is followed by the suffix section.
    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0113Ordered);
    }

    // Middle-at-end layout: the current region is followed by the updated marker.
    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_middle|>updated
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0112MiddleAtEnd);
    }

    // A cursor marker inside the region must not appear in the extracted text.
    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0113Ordered);
    }

    // Git-merge-marker layout: the region sits between the CURRENT line and the separator.
    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0120GitMergeMarkers);
    }

    // Git-merge-marker layout with a cursor marker inside the region.
    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            print<|user_cursor|>ln!(\"hello\");
            =======
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0120GitMergeMarkers);
    }
}