1use crate::{
2 PredictionProvider,
3 example::{ActualCursor, Example},
4 format_prompt::TeacherPrompt,
5 repair,
6};
7use anyhow::{Context as _, Result};
8use edit_prediction::example_spec::encode_cursor_in_patch;
9use zeta_prompt::{
10 CURSOR_MARKER, ZetaFormat, clean_extracted_region_for_format,
11 current_region_markers_for_format, output_end_marker_for_format,
12 output_with_context_for_format,
13};
14
15pub fn run_parse_output(example: &mut Example) -> Result<()> {
16 example
17 .prompt_inputs
18 .as_ref()
19 .context("prompt_inputs required")?;
20
21 let to_parse: Vec<_> = example
22 .predictions
23 .iter()
24 .enumerate()
25 .filter(|(_, p)| !p.actual_output.is_empty())
26 .map(|(ix, p)| (ix, p.actual_output.clone(), p.provider))
27 .collect();
28
29 for (ix, actual_output, provider) in to_parse {
30 let (actual_patch, actual_cursor) =
31 parse_prediction_output(example, &actual_output, provider)?;
32 example.predictions[ix].actual_patch = Some(actual_patch);
33 example.predictions[ix].actual_cursor = actual_cursor;
34 }
35
36 Ok(())
37}
38
39pub fn parse_prediction_output(
40 example: &Example,
41 actual_output: &str,
42 provider: PredictionProvider,
43) -> Result<(String, Option<ActualCursor>)> {
44 match provider {
45 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
46 TeacherPrompt::parse(example, actual_output)
47 }
48 PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
49 PredictionProvider::Repair => repair::parse(example, actual_output),
50 _ => anyhow::bail!(
51 "parse-output only supports Teacher and Zeta2 providers, got {:?}",
52 provider
53 ),
54 }
55}
56
57fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<String> {
58 let (current_marker, end_marker) = current_region_markers_for_format(format);
59
60 let start = prompt.find(current_marker).with_context(|| {
61 format!(
62 "missing current marker '{}' in prompt",
63 current_marker.trim()
64 )
65 })? + current_marker.len();
66
67 let end = prompt[start..]
68 .find(end_marker)
69 .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
70 + start;
71
72 let region = &prompt[start..end];
73 let region = region.replace(CURSOR_MARKER, "");
74 Ok(clean_extracted_region_for_format(format, ®ion))
75}
76
77fn parse_zeta2_output(
78 example: &Example,
79 actual_output: &str,
80 format: ZetaFormat,
81) -> Result<(String, Option<ActualCursor>)> {
82 let prompt = &example.prompt.as_ref().context("prompt required")?.input;
83 let prompt_inputs = example
84 .prompt_inputs
85 .as_ref()
86 .context("prompt_inputs required")?;
87
88 let old_text = extract_zeta2_current_region(prompt, format)?;
89
90 let mut new_text = actual_output.to_string();
91 if let Some(transformed) = output_with_context_for_format(format, &old_text, &new_text)? {
92 new_text = transformed;
93 }
94 let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
95 new_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
96 Some(offset)
97 } else {
98 None
99 };
100
101 if let Some(marker) = output_end_marker_for_format(format) {
102 new_text = new_text
103 .strip_suffix(marker)
104 .unwrap_or(&new_text)
105 .to_string();
106 }
107
108 let mut old_text_normalized = old_text.clone();
109 if !new_text.is_empty() && !new_text.ends_with('\n') {
110 new_text.push('\n');
111 }
112 if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
113 old_text_normalized.push('\n');
114 }
115
116 let old_text_trimmed = old_text.trim_end_matches('\n');
117 let excerpt = prompt_inputs.cursor_excerpt.as_ref();
118 let (editable_region_offset, _) = excerpt
119 .match_indices(old_text_trimmed)
120 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset_in_excerpt))
121 .with_context(|| {
122 format!(
123 "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
124 old_text_trimmed, excerpt
125 )
126 })?;
127
128 let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
129
130 // Use full context so cursor offset (relative to editable region start) aligns with diff content
131 let editable_region_lines = old_text_normalized.lines().count() as u32;
132 let diff = language::unified_diff_with_context(
133 &old_text_normalized,
134 &new_text,
135 editable_region_start_line as u32,
136 editable_region_start_line as u32,
137 editable_region_lines,
138 );
139
140 let formatted_diff = format!(
141 "--- a/{path}\n+++ b/{path}\n{diff}",
142 path = example.spec.cursor_path.to_string_lossy(),
143 );
144
145 let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
146
147 let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
148 ActualCursor::from_editable_region(
149 &example.spec.cursor_path,
150 editable_region_cursor_offset,
151 &new_text,
152 excerpt,
153 editable_region_offset,
154 editable_region_start_line,
155 )
156 });
157
158 Ok((formatted_diff, actual_cursor))
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Editable text that every prompt fixture in this module is expected to yield.
    const EXPECTED_REGION: &str = "println!(\"hello\");\n";

    /// Extracts the current region from `prompt` for `format` and compares it with
    /// `EXPECTED_REGION`.
    #[track_caller]
    fn assert_extracts_expected_region(prompt: &str, format: ZetaFormat) {
        let extracted = extract_zeta2_current_region(prompt, format).unwrap();
        assert_eq!(extracted, EXPECTED_REGION);
    }

    // Fixture with the current section placed before the suffix section.
    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        assert_extracts_expected_region(prompt, ZetaFormat::V0113Ordered);
    }

    // Fixture with the current section placed after the suffix section.
    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_middle|>updated
        "};

        assert_extracts_expected_region(prompt, ZetaFormat::V0112MiddleAtEnd);
    }

    // The user cursor marker inside the region must not appear in the extracted text.
    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        assert_extracts_expected_region(prompt, ZetaFormat::V0113Ordered);
    }

    // Fixture that delimits the region with git-merge-style markers.
    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        assert_extracts_expected_region(prompt, ZetaFormat::V0120GitMergeMarkers);
    }

    // Git-merge-style fixture that also contains the user cursor marker.
    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            print<|user_cursor|>ln!(\"hello\");
            =======
        "};

        assert_extracts_expected_region(prompt, ZetaFormat::V0120GitMergeMarkers);
    }
}