1use crate::{PredictionProvider, example::Example, format_prompt::TeacherPrompt};
2use anyhow::{Context as _, Result};
3use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
4
5pub fn run_parse_output(example: &mut Example) -> Result<()> {
6 let provider = example
7 .prompt
8 .as_ref()
9 .context("prompt required (run format-prompt first)")?
10 .provider;
11 example
12 .prompt_inputs
13 .as_ref()
14 .context("prompt_inputs required")?;
15
16 let parsed_patches: Vec<_> = example
17 .predictions
18 .iter()
19 .enumerate()
20 .filter(|(_, p)| !p.actual_output.is_empty())
21 .map(|(ix, prediction)| {
22 let result = parse_prediction_output(example, &prediction.actual_output, provider);
23 result.map(|(patch, cursor_offset)| (ix, patch, cursor_offset))
24 })
25 .collect::<Result<Vec<_>>>()?;
26
27 for (ix, actual_patch, actual_cursor_offset) in parsed_patches {
28 example.predictions[ix].actual_patch = Some(actual_patch);
29 example.predictions[ix].actual_cursor_offset = actual_cursor_offset;
30 example.predictions[ix].provider = provider;
31 }
32
33 Ok(())
34}
35
36pub fn parse_prediction_output(
37 example: &Example,
38 actual_output: &str,
39 provider: PredictionProvider,
40) -> Result<(String, Option<usize>)> {
41 match provider {
42 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
43 TeacherPrompt::parse(example, actual_output)
44 }
45 PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
46 _ => anyhow::bail!(
47 "parse-output only supports Teacher and Zeta2 providers, got {:?}",
48 provider
49 ),
50 }
51}
52
53fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<String> {
54 let (current_marker, end_marker) = match version {
55 ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
56 ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
57 ("<|fim_middle|>current\n", "<|fim_suffix|>")
58 }
59 ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => (
60 zeta_prompt::v0120_git_merge_markers::START_MARKER,
61 zeta_prompt::v0120_git_merge_markers::SEPARATOR,
62 ),
63 };
64
65 let start = prompt.find(current_marker).with_context(|| {
66 format!(
67 "missing current marker '{}' in prompt",
68 current_marker.trim()
69 )
70 })? + current_marker.len();
71
72 let end = prompt[start..]
73 .find(end_marker)
74 .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
75 + start;
76
77 let region = &prompt[start..end];
78 let region = region.replace(CURSOR_MARKER, "");
79
80 Ok(region)
81}
82
83fn parse_zeta2_output(
84 example: &Example,
85 actual_output: &str,
86 version: ZetaVersion,
87) -> Result<(String, Option<usize>)> {
88 let prompt = &example.prompt.as_ref().context("prompt required")?.input;
89 let prompt_inputs = example
90 .prompt_inputs
91 .as_ref()
92 .context("prompt_inputs required")?;
93
94 let old_text = extract_zeta2_current_region(prompt, version)?;
95
96 let mut new_text = actual_output.to_string();
97 let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
98 new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
99 Some(offset)
100 } else {
101 None
102 };
103
104 let suffix = match version {
105 ZetaVersion::V0131GitMergeMarkersPrefix => {
106 zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
107 }
108 ZetaVersion::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
109 _ => "",
110 };
111 if !suffix.is_empty() {
112 new_text = new_text
113 .strip_suffix(suffix)
114 .unwrap_or(&new_text)
115 .to_string();
116 }
117
118 let mut old_text_normalized = old_text.clone();
119 if !new_text.is_empty() && !new_text.ends_with('\n') {
120 new_text.push('\n');
121 }
122 if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
123 old_text_normalized.push('\n');
124 }
125
126 let old_text_trimmed = old_text.trim_end_matches('\n');
127 let (editable_region_offset, _) = prompt_inputs
128 .content
129 .match_indices(old_text_trimmed)
130 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
131 .with_context(|| {
132 format!(
133 "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
134 old_text_trimmed, &prompt_inputs.content
135 )
136 })?;
137
138 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
139 .matches('\n')
140 .count();
141
142 // Use full context so cursor offset (relative to editable region start) aligns with diff content
143 let editable_region_lines = old_text_normalized.lines().count() as u32;
144 let diff = language::unified_diff_with_context(
145 &old_text_normalized,
146 &new_text,
147 editable_region_start_line as u32,
148 editable_region_start_line as u32,
149 editable_region_lines,
150 );
151
152 let formatted_diff = format!(
153 "--- a/{path}\n+++ b/{path}\n{diff}",
154 path = example.spec.cursor_path.to_string_lossy(),
155 );
156
157 Ok((formatted_diff, cursor_offset))
158}
159
/// Unit tests for Zeta2 prompt-region extraction.
///
/// These cover every `ZetaVersion` marker pair, cursor-marker stripping, and the error
/// paths for missing or misplaced markers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0114_shares_v0113_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0114180EditableRegion).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0112MiddleAtEnd).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_removes_every_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|fim_middle|>current
            <|user_cursor|>print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0131_shares_v0120_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0131GitMergeMarkersPrefix)
                .unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            print<|user_cursor|>ln!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_missing_current_marker() {
        let prompt = indoc::indoc! {"
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
        "};

        let error = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap_err();
        assert!(error.to_string().contains("missing current marker"));
    }

    #[test]
    fn test_extract_zeta2_current_region_end_marker_must_follow_current_marker() {
        // The only `<|fim_suffix|>` precedes the current block, so it must not be used.
        let prompt = indoc::indoc! {"
            <|fim_suffix|>
            <|fim_middle|>current
            println!(\"hello\");
        "};

        let error = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap_err();
        assert!(error.to_string().contains("missing end marker"));
    }
}