parse_output.rs

  1use crate::{PredictionProvider, example::Example, format_prompt::TeacherPrompt};
  2use anyhow::{Context as _, Result};
  3use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
  4
  5pub fn run_parse_output(example: &mut Example) -> Result<()> {
  6    let provider = example
  7        .prompt
  8        .as_ref()
  9        .context("prompt required (run format-prompt first)")?
 10        .provider;
 11    example
 12        .prompt_inputs
 13        .as_ref()
 14        .context("prompt_inputs required")?;
 15
 16    let parsed_patches: Vec<_> = example
 17        .predictions
 18        .iter()
 19        .enumerate()
 20        .filter(|(_, p)| !p.actual_output.is_empty())
 21        .map(|(ix, prediction)| {
 22            let result = parse_prediction_output(example, &prediction.actual_output, provider);
 23            result.map(|(patch, cursor_offset)| (ix, patch, cursor_offset))
 24        })
 25        .collect::<Result<Vec<_>>>()?;
 26
 27    for (ix, actual_patch, actual_cursor_offset) in parsed_patches {
 28        example.predictions[ix].actual_patch = Some(actual_patch);
 29        example.predictions[ix].actual_cursor_offset = actual_cursor_offset;
 30        example.predictions[ix].provider = provider;
 31    }
 32
 33    Ok(())
 34}
 35
 36pub fn parse_prediction_output(
 37    example: &Example,
 38    actual_output: &str,
 39    provider: PredictionProvider,
 40) -> Result<(String, Option<usize>)> {
 41    match provider {
 42        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 43            TeacherPrompt::parse(example, actual_output)
 44        }
 45        PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
 46        _ => anyhow::bail!(
 47            "parse-output only supports Teacher and Zeta2 providers, got {:?}",
 48            provider
 49        ),
 50    }
 51}
 52
 53fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<String> {
 54    let (current_marker, end_marker) = match version {
 55        ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
 56        ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
 57            ("<|fim_middle|>current\n", "<|fim_suffix|>")
 58        }
 59        ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => (
 60            zeta_prompt::v0120_git_merge_markers::START_MARKER,
 61            zeta_prompt::v0120_git_merge_markers::SEPARATOR,
 62        ),
 63    };
 64
 65    let start = prompt.find(current_marker).with_context(|| {
 66        format!(
 67            "missing current marker '{}' in prompt",
 68            current_marker.trim()
 69        )
 70    })? + current_marker.len();
 71
 72    let end = prompt[start..]
 73        .find(end_marker)
 74        .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
 75        + start;
 76
 77    let region = &prompt[start..end];
 78    let region = region.strip_suffix('\n').unwrap_or(region);
 79    let region = region.replace(CURSOR_MARKER, "");
 80
 81    Ok(region)
 82}
 83
 84fn parse_zeta2_output(
 85    example: &Example,
 86    actual_output: &str,
 87    version: ZetaVersion,
 88) -> Result<(String, Option<usize>)> {
 89    let prompt = &example.prompt.as_ref().context("prompt required")?.input;
 90    let prompt_inputs = example
 91        .prompt_inputs
 92        .as_ref()
 93        .context("prompt_inputs required")?;
 94
 95    let old_text = extract_zeta2_current_region(prompt, version)?;
 96
 97    let mut new_text = actual_output.to_string();
 98    let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
 99        new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
100        Some(offset)
101    } else {
102        None
103    };
104
105    let suffix = match version {
106        ZetaVersion::V0131GitMergeMarkersPrefix => {
107            zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
108        }
109        ZetaVersion::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
110        _ => "",
111    };
112    if !suffix.is_empty() {
113        new_text = new_text
114            .strip_suffix(suffix)
115            .unwrap_or(&new_text)
116            .to_string();
117    }
118
119    let mut old_text_normalized = old_text.clone();
120    if !new_text.is_empty() && !new_text.ends_with('\n') {
121        new_text.push('\n');
122    }
123    if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
124        old_text_normalized.push('\n');
125    }
126
127    let old_text_trimmed = old_text.trim_end_matches('\n');
128    let (editable_region_offset, _) = prompt_inputs
129        .content
130        .match_indices(old_text_trimmed)
131        .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
132        .with_context(|| {
133            format!(
134                "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
135                old_text_trimmed, &prompt_inputs.content
136            )
137        })?;
138
139    let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
140        .matches('\n')
141        .count();
142
143    // Use full context so cursor offset (relative to editable region start) aligns with diff content
144    let editable_region_lines = old_text_normalized.lines().count() as u32;
145    let diff = language::unified_diff_with_context(
146        &old_text_normalized,
147        &new_text,
148        editable_region_start_line as u32,
149        editable_region_start_line as u32,
150        editable_region_lines,
151    );
152
153    let formatted_diff = format!(
154        "--- a/{path}\n+++ b/{path}\n{diff}",
155        path = example.spec.cursor_path.to_string_lossy(),
156    );
157
158    Ok((formatted_diff, cursor_offset))
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// V0113 places the current region between `<|fim_middle|>current` and `<|fim_suffix|>`.
    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    /// V0112 places the current region between `<|fim_middle|>current` and `<|fim_middle|>updated`.
    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0112MiddleAtEnd).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    /// The cursor marker must be removed so the region can be matched against file content.
    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    /// V0120 delimits the current region with git merge markers.
    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            print<|user_cursor|>ln!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    /// V0131 reuses the V0120 start marker and separator when extracting the current region.
    #[test]
    fn test_extract_zeta2_current_region_v0131_shares_v0120_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0131GitMergeMarkersPrefix)
                .unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    /// A prompt without the current marker is reported as an error rather than panicking.
    #[test]
    fn test_extract_zeta2_current_region_missing_current_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
        "};

        let error = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap_err();
        assert!(
            error.to_string().contains("missing current marker"),
            "{error}"
        );
    }

    /// A current marker with no following end marker is reported as an error.
    #[test]
    fn test_extract_zeta2_current_region_missing_end_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
        "};

        let error = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap_err();
        assert!(error.to_string().contains("missing end marker"), "{error}");
    }
}