parse_output.rs

 1use crate::{
 2    PredictionProvider,
 3    example::{ActualCursor, Example},
 4    format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt},
 5    repair,
 6};
 7use anyhow::{Context as _, Result};
 8use zeta_prompt::{ZetaFormat, parse_zeta2_model_output, parsed_output_to_patch};
 9
10pub fn run_parse_output(example: &mut Example) -> Result<()> {
11    example
12        .prompt_inputs
13        .as_ref()
14        .context("prompt_inputs required")?;
15
16    let to_parse: Vec<_> = example
17        .predictions
18        .iter()
19        .enumerate()
20        .filter(|(_, p)| !p.actual_output.is_empty())
21        .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
22        .collect();
23
24    for (ix, actual_output, provider) in to_parse {
25        let (actual_patch, actual_cursor) =
26            parse_prediction_output(example, &actual_output, provider)?;
27        example.predictions[ix].actual_patch = Some(actual_patch);
28        example.predictions[ix].actual_cursor = actual_cursor;
29    }
30
31    Ok(())
32}
33
34pub fn parse_prediction_output(
35    example: &Example,
36    actual_output: &str,
37    provider: PredictionProvider,
38) -> Result<(String, Option<ActualCursor>)> {
39    match provider {
40        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
41            TeacherPrompt::parse(example, actual_output)
42        }
43        PredictionProvider::TeacherMultiRegion(_)
44        | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
45            TeacherMultiRegionPrompt::parse(example, actual_output)
46        }
47        PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
48        PredictionProvider::Repair => repair::parse(example, actual_output),
49        _ => anyhow::bail!(
50            "parse-output only supports Teacher and Zeta2 providers, got {:?}",
51            provider
52        ),
53    }
54}
55
56fn parse_zeta2_output(
57    example: &Example,
58    actual_output: &str,
59    format: ZetaFormat,
60) -> Result<(String, Option<ActualCursor>)> {
61    let prompt_inputs = example
62        .prompt_inputs
63        .as_ref()
64        .context("prompt_inputs required")?;
65
66    let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?;
67    let range_in_excerpt = parsed.range_in_excerpt.clone();
68    let excerpt = prompt_inputs.cursor_excerpt.as_ref();
69    let editable_region_offset = range_in_excerpt.start;
70    let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
71
72    let mut new_text = parsed.new_editable_region.clone();
73    if !new_text.is_empty() && !new_text.ends_with('\n') {
74        new_text.push('\n');
75    }
76
77    let cursor_offset = parsed.cursor_offset_in_new_editable_region;
78    let formatted_diff = parsed_output_to_patch(prompt_inputs, parsed)?;
79
80    let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
81        ActualCursor::from_editable_region(
82            &example.spec.cursor_path,
83            editable_region_cursor_offset,
84            &new_text,
85            excerpt,
86            editable_region_offset,
87            editable_region_start_line,
88        )
89    });
90
91    Ok((formatted_diff, actual_cursor))
92}