1use crate::{
2 PredictionProvider,
3 example::{ActualCursor, Example},
4 format_prompt::TeacherPrompt,
5};
6use anyhow::{Context as _, Result};
7use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
8
9pub fn run_parse_output(example: &mut Example) -> Result<()> {
10 let provider = example
11 .prompt
12 .as_ref()
13 .context("prompt required (run format-prompt first)")?
14 .provider;
15 example
16 .prompt_inputs
17 .as_ref()
18 .context("prompt_inputs required")?;
19
20 let parsed_patches: Vec<_> = example
21 .predictions
22 .iter()
23 .enumerate()
24 .filter(|(_, p)| !p.actual_output.is_empty())
25 .map(|(ix, prediction)| {
26 let result = parse_prediction_output(example, &prediction.actual_output, provider);
27 result.map(|(patch, cursor_offset)| (ix, patch, cursor_offset))
28 })
29 .collect::<Result<Vec<_>>>()?;
30
31 for (ix, actual_patch, actual_cursor) in parsed_patches {
32 example.predictions[ix].actual_patch = Some(actual_patch);
33 example.predictions[ix].actual_cursor = actual_cursor;
34 example.predictions[ix].provider = provider;
35 }
36
37 Ok(())
38}
39
40pub fn parse_prediction_output(
41 example: &Example,
42 actual_output: &str,
43 provider: PredictionProvider,
44) -> Result<(String, Option<ActualCursor>)> {
45 match provider {
46 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
47 TeacherPrompt::parse(example, actual_output)
48 }
49 PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
50 _ => anyhow::bail!(
51 "parse-output only supports Teacher and Zeta2 providers, got {:?}",
52 provider
53 ),
54 }
55}
56
57fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<String> {
58 let (current_marker, end_marker) = match version {
59 ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
60 ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
61 ("<|fim_middle|>current\n", "<|fim_suffix|>")
62 }
63 ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => (
64 zeta_prompt::v0120_git_merge_markers::START_MARKER,
65 zeta_prompt::v0120_git_merge_markers::SEPARATOR,
66 ),
67 };
68
69 let start = prompt.find(current_marker).with_context(|| {
70 format!(
71 "missing current marker '{}' in prompt",
72 current_marker.trim()
73 )
74 })? + current_marker.len();
75
76 let end = prompt[start..]
77 .find(end_marker)
78 .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
79 + start;
80
81 let region = &prompt[start..end];
82 let region = region.replace(CURSOR_MARKER, "");
83
84 Ok(region)
85}
86
87fn parse_zeta2_output(
88 example: &Example,
89 actual_output: &str,
90 version: ZetaVersion,
91) -> Result<(String, Option<ActualCursor>)> {
92 let prompt = &example.prompt.as_ref().context("prompt required")?.input;
93 let prompt_inputs = example
94 .prompt_inputs
95 .as_ref()
96 .context("prompt_inputs required")?;
97
98 let old_text = extract_zeta2_current_region(prompt, version)?;
99
100 let mut new_text = actual_output.to_string();
101 let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
102 new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
103 Some(offset)
104 } else {
105 None
106 };
107
108 let suffix = match version {
109 ZetaVersion::V0131GitMergeMarkersPrefix => {
110 zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
111 }
112 ZetaVersion::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
113 _ => "",
114 };
115 if !suffix.is_empty() {
116 new_text = new_text
117 .strip_suffix(suffix)
118 .unwrap_or(&new_text)
119 .to_string();
120 }
121
122 let mut old_text_normalized = old_text.clone();
123 if !new_text.is_empty() && !new_text.ends_with('\n') {
124 new_text.push('\n');
125 }
126 if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
127 old_text_normalized.push('\n');
128 }
129
130 let old_text_trimmed = old_text.trim_end_matches('\n');
131 let (editable_region_offset, _) = prompt_inputs
132 .content
133 .match_indices(old_text_trimmed)
134 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
135 .with_context(|| {
136 format!(
137 "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
138 old_text_trimmed, &prompt_inputs.content
139 )
140 })?;
141
142 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
143 .matches('\n')
144 .count();
145
146 // Use full context so cursor offset (relative to editable region start) aligns with diff content
147 let editable_region_lines = old_text_normalized.lines().count() as u32;
148 let diff = language::unified_diff_with_context(
149 &old_text_normalized,
150 &new_text,
151 editable_region_start_line as u32,
152 editable_region_start_line as u32,
153 editable_region_lines,
154 );
155
156 let formatted_diff = format!(
157 "--- a/{path}\n+++ b/{path}\n{diff}",
158 path = example.spec.cursor_path.to_string_lossy(),
159 );
160
161 let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
162 ActualCursor::from_editable_region(
163 &example.spec.cursor_path,
164 editable_region_cursor_offset,
165 &new_text,
166 &prompt_inputs.content,
167 editable_region_offset,
168 editable_region_start_line,
169 )
170 });
171
172 Ok((formatted_diff, actual_cursor))
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Region text that every fixture prompt below should yield.
    const EXPECTED_REGION: &str = "println!(\"hello\");\n";

    /// Extracts the current region from `prompt` for `version` and checks that
    /// it equals `EXPECTED_REGION`.
    #[track_caller]
    fn assert_extracts_expected_region(prompt: &str, version: ZetaVersion) {
        let region = extract_zeta2_current_region(prompt, version).unwrap();
        assert_eq!(region, EXPECTED_REGION);
    }

    // Region sits between the "current" marker and the suffix marker.
    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0113Ordered);
    }

    // Region sits between the "current" marker and the "updated" marker.
    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_middle|>updated
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0112MiddleAtEnd);
    }

    // A cursor marker inside the region is removed from the extracted text.
    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0113Ordered);
    }

    // Git-merge-marker format: region sits between the start marker and the separator.
    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0120GitMergeMarkers);
    }

    // Git-merge-marker format with a cursor marker inside the region.
    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            print<|user_cursor|>ln!(\"hello\");
            =======
        "};

        assert_extracts_expected_region(prompt, ZetaVersion::V0120GitMergeMarkers);
    }
}