1use crate::{
2 PredictionProvider,
3 example::{ActualCursor, Example},
4 format_prompt::TeacherPrompt,
5 repair,
6};
7use anyhow::{Context as _, Result};
8use edit_prediction::example_spec::encode_cursor_in_patch;
9use zeta_prompt::{CURSOR_MARKER, ZetaFormat};
10
11pub fn run_parse_output(example: &mut Example) -> Result<()> {
12 example
13 .prompt_inputs
14 .as_ref()
15 .context("prompt_inputs required")?;
16
17 let to_parse: Vec<_> = example
18 .predictions
19 .iter()
20 .enumerate()
21 .filter(|(_, p)| !p.actual_output.is_empty())
22 .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
23 .collect();
24
25 for (ix, actual_output, provider) in to_parse {
26 let (actual_patch, actual_cursor) =
27 parse_prediction_output(example, &actual_output, provider)?;
28 example.predictions[ix].actual_patch = Some(actual_patch);
29 example.predictions[ix].actual_cursor = actual_cursor;
30 }
31
32 Ok(())
33}
34
35pub fn parse_prediction_output(
36 example: &Example,
37 actual_output: &str,
38 provider: PredictionProvider,
39) -> Result<(String, Option<ActualCursor>)> {
40 match provider {
41 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
42 TeacherPrompt::parse(example, actual_output)
43 }
44 PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
45 PredictionProvider::Repair => repair::parse(example, actual_output),
46 _ => anyhow::bail!(
47 "parse-output only supports Teacher and Zeta2 providers, got {:?}",
48 provider
49 ),
50 }
51}
52
53fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<String> {
54 let (current_marker, end_marker) = match format {
55 ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
56 ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
57 ("<|fim_middle|>current\n", "<|fim_suffix|>")
58 }
59 ZetaFormat::V0120GitMergeMarkers
60 | ZetaFormat::V0131GitMergeMarkersPrefix
61 | ZetaFormat::V0211Prefill => (
62 zeta_prompt::v0120_git_merge_markers::START_MARKER,
63 zeta_prompt::v0120_git_merge_markers::SEPARATOR,
64 ),
65 ZetaFormat::V0211SeedCoder => (
66 zeta_prompt::seed_coder::START_MARKER,
67 zeta_prompt::seed_coder::SEPARATOR,
68 ),
69 };
70
71 let start = prompt.find(current_marker).with_context(|| {
72 format!(
73 "missing current marker '{}' in prompt",
74 current_marker.trim()
75 )
76 })? + current_marker.len();
77
78 let end = prompt[start..]
79 .find(end_marker)
80 .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
81 + start;
82
83 let region = &prompt[start..end];
84 let region = region.replace(CURSOR_MARKER, "");
85
86 Ok(region)
87}
88
89fn parse_zeta2_output(
90 example: &Example,
91 actual_output: &str,
92 format: ZetaFormat,
93) -> Result<(String, Option<ActualCursor>)> {
94 let prompt = &example.prompt.as_ref().context("prompt required")?.input;
95 let prompt_inputs = example
96 .prompt_inputs
97 .as_ref()
98 .context("prompt_inputs required")?;
99
100 let old_text = extract_zeta2_current_region(prompt, format)?;
101
102 let mut new_text = actual_output.to_string();
103 let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
104 new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
105 Some(offset)
106 } else {
107 None
108 };
109
110 let suffix = match format {
111 ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
112 zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
113 }
114 ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
115 ZetaFormat::V0112MiddleAtEnd
116 | ZetaFormat::V0113Ordered
117 | ZetaFormat::V0114180EditableRegion => "",
118 ZetaFormat::V0211SeedCoder => zeta_prompt::seed_coder::END_MARKER,
119 };
120 if !suffix.is_empty() {
121 new_text = new_text
122 .strip_suffix(suffix)
123 .unwrap_or(&new_text)
124 .to_string();
125 }
126
127 let mut old_text_normalized = old_text.clone();
128 if !new_text.is_empty() && !new_text.ends_with('\n') {
129 new_text.push('\n');
130 }
131 if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
132 old_text_normalized.push('\n');
133 }
134
135 let old_text_trimmed = old_text.trim_end_matches('\n');
136 let (editable_region_offset, _) = prompt_inputs
137 .content
138 .match_indices(old_text_trimmed)
139 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
140 .with_context(|| {
141 format!(
142 "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
143 old_text_trimmed, &prompt_inputs.content
144 )
145 })?;
146
147 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
148 .matches('\n')
149 .count();
150
151 // Use full context so cursor offset (relative to editable region start) aligns with diff content
152 let editable_region_lines = old_text_normalized.lines().count() as u32;
153 let diff = language::unified_diff_with_context(
154 &old_text_normalized,
155 &new_text,
156 editable_region_start_line as u32,
157 editable_region_start_line as u32,
158 editable_region_lines,
159 );
160
161 let formatted_diff = format!(
162 "--- a/{path}\n+++ b/{path}\n{diff}",
163 path = example.spec.cursor_path.to_string_lossy(),
164 );
165
166 let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
167
168 let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
169 ActualCursor::from_editable_region(
170 &example.spec.cursor_path,
171 editable_region_cursor_offset,
172 &new_text,
173 &prompt_inputs.content,
174 editable_region_offset,
175 editable_region_start_line,
176 )
177 });
178
179 Ok((formatted_diff, actual_cursor))
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a prompt whose editable region (`region`) sits between `start` and
    /// `end`, surrounded by unrelated file text.
    fn prompt_with_region(start: &str, region: &str, end: &str) -> String {
        format!("<|file_sep|>src/main.rs\nfn main() {{\n{start}{region}{end}\n}}\n")
    }

    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0112MiddleAtEnd).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            print<|user_cursor|>ln!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_all_merge_marker_formats() {
        // These formats share the v0120 start marker and separator; build the
        // prompt from the crate's constants so the test tracks them.
        let region_with_cursor = format!("print{CURSOR_MARKER}ln!(\"hello\");\n");
        for format in [
            ZetaFormat::V0120GitMergeMarkers,
            ZetaFormat::V0131GitMergeMarkersPrefix,
            ZetaFormat::V0211Prefill,
        ] {
            let prompt = prompt_with_region(
                zeta_prompt::v0120_git_merge_markers::START_MARKER,
                &region_with_cursor,
                zeta_prompt::v0120_git_merge_markers::SEPARATOR,
            );
            let region = extract_zeta2_current_region(&prompt, format).unwrap();
            assert_eq!(region, "println!(\"hello\");\n");
        }
    }

    #[test]
    fn test_extract_zeta2_current_region_seed_coder() {
        let prompt = prompt_with_region(
            zeta_prompt::seed_coder::START_MARKER,
            &format!("print{CURSOR_MARKER}ln!(\"hello\");\n"),
            zeta_prompt::seed_coder::SEPARATOR,
        );

        let region = extract_zeta2_current_region(&prompt, ZetaFormat::V0211SeedCoder).unwrap();
        assert_eq!(region, "println!(\"hello\");\n");
    }

    #[test]
    fn test_extract_zeta2_current_region_missing_markers() {
        let missing_start = "<|file_sep|>src/main.rs\nfn main() {}\n";
        assert!(extract_zeta2_current_region(missing_start, ZetaFormat::V0113Ordered).is_err());

        let missing_end = "<|fim_middle|>current\nprintln!(\"hello\");\n";
        assert!(extract_zeta2_current_region(missing_end, ZetaFormat::V0113Ordered).is_err());
    }
}