1use crate::{
2 PredictionProvider,
3 example::{ActualCursor, Example},
4 format_prompt::TeacherPrompt,
5 repair,
6};
7use anyhow::{Context as _, Result};
8use zeta_prompt::{CURSOR_MARKER, ZetaFormat};
9
10pub fn run_parse_output(example: &mut Example) -> Result<()> {
11 example
12 .prompt_inputs
13 .as_ref()
14 .context("prompt_inputs required")?;
15
16 let to_parse: Vec<_> = example
17 .predictions
18 .iter()
19 .enumerate()
20 .filter(|(_, p)| !p.actual_output.is_empty())
21 .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
22 .collect();
23
24 for (ix, actual_output, provider) in to_parse {
25 let (actual_patch, actual_cursor) =
26 parse_prediction_output(example, &actual_output, provider)?;
27 example.predictions[ix].actual_patch = Some(actual_patch);
28 example.predictions[ix].actual_cursor = actual_cursor;
29 }
30
31 Ok(())
32}
33
34pub fn parse_prediction_output(
35 example: &Example,
36 actual_output: &str,
37 provider: PredictionProvider,
38) -> Result<(String, Option<ActualCursor>)> {
39 match provider {
40 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
41 TeacherPrompt::parse(example, actual_output)
42 }
43 PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
44 PredictionProvider::Repair => repair::parse(example, actual_output),
45 _ => anyhow::bail!(
46 "parse-output only supports Teacher and Zeta2 providers, got {:?}",
47 provider
48 ),
49 }
50}
51
52fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<String> {
53 let (current_marker, end_marker) = match format {
54 ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
55 ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
56 ("<|fim_middle|>current\n", "<|fim_suffix|>")
57 }
58 ZetaFormat::V0120GitMergeMarkers
59 | ZetaFormat::V0131GitMergeMarkersPrefix
60 | ZetaFormat::V0211Prefill => (
61 zeta_prompt::v0120_git_merge_markers::START_MARKER,
62 zeta_prompt::v0120_git_merge_markers::SEPARATOR,
63 ),
64 ZetaFormat::V0211SeedCoder => (
65 zeta_prompt::seed_coder::START_MARKER,
66 zeta_prompt::seed_coder::SEPARATOR,
67 ),
68 };
69
70 let start = prompt.find(current_marker).with_context(|| {
71 format!(
72 "missing current marker '{}' in prompt",
73 current_marker.trim()
74 )
75 })? + current_marker.len();
76
77 let end = prompt[start..]
78 .find(end_marker)
79 .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
80 + start;
81
82 let region = &prompt[start..end];
83 let region = region.replace(CURSOR_MARKER, "");
84
85 Ok(region)
86}
87
88fn parse_zeta2_output(
89 example: &Example,
90 actual_output: &str,
91 format: ZetaFormat,
92) -> Result<(String, Option<ActualCursor>)> {
93 let prompt = &example.prompt.as_ref().context("prompt required")?.input;
94 let prompt_inputs = example
95 .prompt_inputs
96 .as_ref()
97 .context("prompt_inputs required")?;
98
99 let old_text = extract_zeta2_current_region(prompt, format)?;
100
101 let mut new_text = actual_output.to_string();
102 let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
103 new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
104 Some(offset)
105 } else {
106 None
107 };
108
109 let suffix = match format {
110 ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
111 zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
112 }
113 ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
114 ZetaFormat::V0112MiddleAtEnd
115 | ZetaFormat::V0113Ordered
116 | ZetaFormat::V0114180EditableRegion => "",
117 ZetaFormat::V0211SeedCoder => zeta_prompt::seed_coder::END_MARKER,
118 };
119 if !suffix.is_empty() {
120 new_text = new_text
121 .strip_suffix(suffix)
122 .unwrap_or(&new_text)
123 .to_string();
124 }
125
126 let mut old_text_normalized = old_text.clone();
127 if !new_text.is_empty() && !new_text.ends_with('\n') {
128 new_text.push('\n');
129 }
130 if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
131 old_text_normalized.push('\n');
132 }
133
134 let old_text_trimmed = old_text.trim_end_matches('\n');
135 let (editable_region_offset, _) = prompt_inputs
136 .content
137 .match_indices(old_text_trimmed)
138 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
139 .with_context(|| {
140 format!(
141 "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
142 old_text_trimmed, &prompt_inputs.content
143 )
144 })?;
145
146 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
147 .matches('\n')
148 .count();
149
150 // Use full context so cursor offset (relative to editable region start) aligns with diff content
151 let editable_region_lines = old_text_normalized.lines().count() as u32;
152 let diff = language::unified_diff_with_context(
153 &old_text_normalized,
154 &new_text,
155 editable_region_start_line as u32,
156 editable_region_start_line as u32,
157 editable_region_lines,
158 );
159
160 let formatted_diff = format!(
161 "--- a/{path}\n+++ b/{path}\n{diff}",
162 path = example.spec.cursor_path.to_string_lossy(),
163 );
164
165 let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
166 ActualCursor::from_editable_region(
167 &example.spec.cursor_path,
168 editable_region_cursor_offset,
169 &new_text,
170 &prompt_inputs.content,
171 editable_region_offset,
172 editable_region_start_line,
173 )
174 });
175
176 Ok((formatted_diff, actual_cursor))
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0112MiddleAtEnd).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            print<|user_cursor|>ln!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    /// The seed-coder markers are built from the crate's own constants so the
    /// test does not hard-code their exact spelling.
    #[test]
    fn test_extract_zeta2_current_region_seed_coder() {
        let prompt = format!(
            "<|file_sep|>src/main.rs\n{}print{}ln!(\"hello\");\n{}",
            zeta_prompt::seed_coder::START_MARKER,
            CURSOR_MARKER,
            zeta_prompt::seed_coder::SEPARATOR,
        );

        let region = extract_zeta2_current_region(&prompt, ZetaFormat::V0211SeedCoder).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_missing_current_marker() {
        let prompt = "<|file_sep|>src/main.rs\nfn main() {}\n";

        let error = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap_err();
        assert!(
            error.to_string().contains("missing current marker"),
            "{error}"
        );
    }

    #[test]
    fn test_extract_zeta2_current_region_missing_end_marker() {
        let prompt = "<|fim_middle|>current\nprintln!(\"hello\");\n";

        let error = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap_err();
        assert!(error.to_string().contains("missing end marker"), "{error}");
    }
}