1use crate::{
2 PredictionProvider,
3 example::{ActualCursor, Example},
4 format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt},
5 repair,
6};
7use anyhow::{Context as _, Result};
8use zeta_prompt::{ZetaFormat, parse_zeta2_model_output, parsed_output_to_patch};
9
10pub fn run_parse_output(example: &mut Example) -> Result<()> {
11 example
12 .prompt_inputs
13 .as_ref()
14 .context("prompt_inputs required")?;
15
16 let to_parse: Vec<_> = example
17 .predictions
18 .iter()
19 .enumerate()
20 .filter(|(_, p)| !p.actual_output.is_empty())
21 .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
22 .collect();
23
24 for (ix, actual_output, provider) in to_parse {
25 let (actual_patch, actual_cursor) =
26 parse_prediction_output(example, &actual_output, provider)?;
27 example.predictions[ix].actual_patch = Some(actual_patch);
28 example.predictions[ix].actual_cursor = actual_cursor;
29 }
30
31 Ok(())
32}
33
34pub fn parse_prediction_output(
35 example: &Example,
36 actual_output: &str,
37 provider: PredictionProvider,
38) -> Result<(String, Option<ActualCursor>)> {
39 match provider {
40 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
41 TeacherPrompt::parse(example, actual_output)
42 }
43 PredictionProvider::TeacherMultiRegion(_)
44 | PredictionProvider::TeacherMultiRegionNonBatching(_) => {
45 TeacherMultiRegionPrompt::parse(example, actual_output)
46 }
47 PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
48 PredictionProvider::Repair => repair::parse(example, actual_output),
49 _ => anyhow::bail!(
50 "parse-output only supports Teacher and Zeta2 providers, got {:?}",
51 provider
52 ),
53 }
54}
55
56fn parse_zeta2_output(
57 example: &Example,
58 actual_output: &str,
59 format: ZetaFormat,
60) -> Result<(String, Option<ActualCursor>)> {
61 let prompt_inputs = example
62 .prompt_inputs
63 .as_ref()
64 .context("prompt_inputs required")?;
65
66 let parsed = parse_zeta2_model_output(actual_output, format, prompt_inputs)?;
67 let range_in_excerpt = parsed.range_in_excerpt.clone();
68 let excerpt = prompt_inputs.cursor_excerpt.as_ref();
69 let editable_region_offset = range_in_excerpt.start;
70 let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
71
72 let mut new_text = parsed.new_editable_region.clone();
73 if !new_text.is_empty() && !new_text.ends_with('\n') {
74 new_text.push('\n');
75 }
76
77 let cursor_offset = parsed.cursor_offset_in_new_editable_region;
78 let formatted_diff = parsed_output_to_patch(prompt_inputs, parsed)?;
79
80 let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
81 ActualCursor::from_editable_region(
82 &example.spec.cursor_path,
83 editable_region_cursor_offset,
84 &new_text,
85 excerpt,
86 editable_region_offset,
87 editable_region_start_line,
88 )
89 });
90
91 Ok((formatted_diff, actual_cursor))
92}