1use crate::{
2 PredictionProvider,
3 example::{ActualCursor, Example},
4 format_prompt::TeacherPrompt,
5 repair,
6};
7use anyhow::{Context as _, Result};
8use zeta_prompt::{CURSOR_MARKER, ZetaFormat};
9
10pub fn run_parse_output(example: &mut Example) -> Result<()> {
11 example
12 .prompt_inputs
13 .as_ref()
14 .context("prompt_inputs required")?;
15
16 let to_parse: Vec<_> = example
17 .predictions
18 .iter()
19 .enumerate()
20 .filter(|(_, p)| !p.actual_output.is_empty())
21 .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
22 .collect();
23
24 for (ix, actual_output, provider) in to_parse {
25 let (actual_patch, actual_cursor) =
26 parse_prediction_output(example, &actual_output, provider)?;
27 example.predictions[ix].actual_patch = Some(actual_patch);
28 example.predictions[ix].actual_cursor = actual_cursor;
29 }
30
31 Ok(())
32}
33
34pub fn parse_prediction_output(
35 example: &Example,
36 actual_output: &str,
37 provider: PredictionProvider,
38) -> Result<(String, Option<ActualCursor>)> {
39 match provider {
40 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
41 TeacherPrompt::parse(example, actual_output)
42 }
43 PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
44 PredictionProvider::Repair => repair::parse(example, actual_output),
45 _ => anyhow::bail!(
46 "parse-output only supports Teacher and Zeta2 providers, got {:?}",
47 provider
48 ),
49 }
50}
51
52fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<String> {
53 let (current_marker, end_marker) = match format {
54 ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
55 ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
56 ("<|fim_middle|>current\n", "<|fim_suffix|>")
57 }
58 ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => (
59 zeta_prompt::v0120_git_merge_markers::START_MARKER,
60 zeta_prompt::v0120_git_merge_markers::SEPARATOR,
61 ),
62 };
63
64 let start = prompt.find(current_marker).with_context(|| {
65 format!(
66 "missing current marker '{}' in prompt",
67 current_marker.trim()
68 )
69 })? + current_marker.len();
70
71 let end = prompt[start..]
72 .find(end_marker)
73 .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
74 + start;
75
76 let region = &prompt[start..end];
77 let region = region.replace(CURSOR_MARKER, "");
78
79 Ok(region)
80}
81
82fn parse_zeta2_output(
83 example: &Example,
84 actual_output: &str,
85 format: ZetaFormat,
86) -> Result<(String, Option<ActualCursor>)> {
87 let prompt = &example.prompt.as_ref().context("prompt required")?.input;
88 let prompt_inputs = example
89 .prompt_inputs
90 .as_ref()
91 .context("prompt_inputs required")?;
92
93 let old_text = extract_zeta2_current_region(prompt, format)?;
94
95 let mut new_text = actual_output.to_string();
96 let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
97 new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
98 Some(offset)
99 } else {
100 None
101 };
102
103 let suffix = match format {
104 ZetaFormat::V0131GitMergeMarkersPrefix => {
105 zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
106 }
107 ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
108 _ => "",
109 };
110 if !suffix.is_empty() {
111 new_text = new_text
112 .strip_suffix(suffix)
113 .unwrap_or(&new_text)
114 .to_string();
115 }
116
117 let mut old_text_normalized = old_text.clone();
118 if !new_text.is_empty() && !new_text.ends_with('\n') {
119 new_text.push('\n');
120 }
121 if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
122 old_text_normalized.push('\n');
123 }
124
125 let old_text_trimmed = old_text.trim_end_matches('\n');
126 let (editable_region_offset, _) = prompt_inputs
127 .content
128 .match_indices(old_text_trimmed)
129 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
130 .with_context(|| {
131 format!(
132 "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
133 old_text_trimmed, &prompt_inputs.content
134 )
135 })?;
136
137 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
138 .matches('\n')
139 .count();
140
141 // Use full context so cursor offset (relative to editable region start) aligns with diff content
142 let editable_region_lines = old_text_normalized.lines().count() as u32;
143 let diff = language::unified_diff_with_context(
144 &old_text_normalized,
145 &new_text,
146 editable_region_start_line as u32,
147 editable_region_start_line as u32,
148 editable_region_lines,
149 );
150
151 let formatted_diff = format!(
152 "--- a/{path}\n+++ b/{path}\n{diff}",
153 path = example.spec.cursor_path.to_string_lossy(),
154 );
155
156 let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
157 ActualCursor::from_editable_region(
158 &example.spec.cursor_path,
159 editable_region_cursor_offset,
160 &new_text,
161 &prompt_inputs.content,
162 editable_region_offset,
163 editable_region_start_line,
164 )
165 });
166
167 Ok((formatted_diff, actual_cursor))
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the current region from `prompt` using `format` and checks
    /// that it is exactly the `println!` line with no cursor marker.
    #[track_caller]
    fn assert_extracts_println(prompt: &str, format: ZetaFormat) {
        let region = extract_zeta2_current_region(prompt, format).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    // V0113: the current region sits between `<|fim_middle|>current` and `<|fim_suffix|>`.
    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        assert_extracts_println(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_middle|>current
                println!(\"hello\");
                <|fim_suffix|>
                }
                <|fim_middle|>updated
            "},
            ZetaFormat::V0113Ordered,
        );
    }

    // V0112: the current region is at the end, terminated by `<|fim_middle|>updated`.
    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        assert_extracts_println(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_suffix|>
                }
                <|fim_middle|>current
                println!(\"hello\");
                <|fim_middle|>updated
            "},
            ZetaFormat::V0112MiddleAtEnd,
        );
    }

    // The user cursor marker inside the region must be stripped from the result.
    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        assert_extracts_println(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_middle|>current
                print<|user_cursor|>ln!(\"hello\");
                <|fim_suffix|>
                }
                <|fim_middle|>updated
            "},
            ZetaFormat::V0113Ordered,
        );
    }

    // V0120: the region is delimited by git-merge-style markers.
    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        assert_extracts_println(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_suffix|>
                }
                <|fim_middle|><<<<<<< CURRENT
                println!(\"hello\");
                =======
            "},
            ZetaFormat::V0120GitMergeMarkers,
        );
    }

    // V0120 with a user cursor marker inside the region.
    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        assert_extracts_println(
            indoc::indoc! {"
                <|file_sep|>src/main.rs
                <|fim_prefix|>
                fn main() {
                <|fim_suffix|>
                }
                <|fim_middle|><<<<<<< CURRENT
                print<|user_cursor|>ln!(\"hello\");
                =======
            "},
            ZetaFormat::V0120GitMergeMarkers,
        );
    }
}