1use crate::{
2 PredictionProvider,
3 example::{ActualCursor, Example},
4 format_prompt::TeacherPrompt,
5 repair,
6};
7use anyhow::{Context as _, Result};
8use edit_prediction::example_spec::encode_cursor_in_patch;
9use zeta_prompt::{CURSOR_MARKER, ZetaFormat, output_end_marker_for_format, resolve_cursor_region};
10
11pub fn run_parse_output(example: &mut Example) -> Result<()> {
12 example
13 .prompt_inputs
14 .as_ref()
15 .context("prompt_inputs required")?;
16
17 let to_parse: Vec<_> = example
18 .predictions
19 .iter()
20 .enumerate()
21 .filter(|(_, p)| !p.actual_output.is_empty())
22 .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
23 .collect();
24
25 for (ix, actual_output, provider) in to_parse {
26 let (actual_patch, actual_cursor) =
27 parse_prediction_output(example, &actual_output, provider)?;
28 example.predictions[ix].actual_patch = Some(actual_patch);
29 example.predictions[ix].actual_cursor = actual_cursor;
30 }
31
32 Ok(())
33}
34
35pub fn parse_prediction_output(
36 example: &Example,
37 actual_output: &str,
38 provider: PredictionProvider,
39) -> Result<(String, Option<ActualCursor>)> {
40 match provider {
41 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
42 TeacherPrompt::parse(example, actual_output)
43 }
44 PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
45 PredictionProvider::Repair => repair::parse(example, actual_output),
46 _ => anyhow::bail!(
47 "parse-output only supports Teacher and Zeta2 providers, got {:?}",
48 provider
49 ),
50 }
51}
52
53fn parse_zeta2_output(
54 example: &Example,
55 actual_output: &str,
56 format: ZetaFormat,
57) -> Result<(String, Option<ActualCursor>)> {
58 let prompt_inputs = example
59 .prompt_inputs
60 .as_ref()
61 .context("prompt_inputs required")?;
62
63 let (context, editable_range, _, _) = resolve_cursor_region(prompt_inputs, format);
64 let old_text = context[editable_range].to_string();
65
66 let mut new_text = actual_output.to_string();
67 let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
68 new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
69 Some(offset)
70 } else {
71 None
72 };
73
74 if let Some(marker) = output_end_marker_for_format(format) {
75 new_text = new_text
76 .strip_suffix(marker)
77 .unwrap_or(&new_text)
78 .to_string();
79 }
80
81 let mut old_text_normalized = old_text.clone();
82 if !new_text.is_empty() && !new_text.ends_with('\n') {
83 new_text.push('\n');
84 }
85 if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
86 old_text_normalized.push('\n');
87 }
88
89 let old_text_trimmed = old_text.trim_end_matches('\n');
90 let excerpt = prompt_inputs.cursor_excerpt.as_ref();
91 let (editable_region_offset, _) = excerpt
92 .match_indices(old_text_trimmed)
93 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset_in_excerpt))
94 .with_context(|| {
95 format!(
96 "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
97 old_text_trimmed, excerpt
98 )
99 })?;
100
101 let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
102
103 // Use full context so cursor offset (relative to editable region start) aligns with diff content
104 let editable_region_lines = old_text_normalized.lines().count() as u32;
105 let diff = language::unified_diff_with_context(
106 &old_text_normalized,
107 &new_text,
108 editable_region_start_line as u32,
109 editable_region_start_line as u32,
110 editable_region_lines,
111 );
112
113 let formatted_diff = format!(
114 "--- a/{path}\n+++ b/{path}\n{diff}",
115 path = example.spec.cursor_path.to_string_lossy(),
116 );
117
118 let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
119
120 let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
121 ActualCursor::from_editable_region(
122 &example.spec.cursor_path,
123 editable_region_cursor_offset,
124 &new_text,
125 excerpt,
126 editable_region_offset,
127 editable_region_start_line,
128 )
129 });
130
131 Ok((formatted_diff, actual_cursor))
132}