1use crate::{PredictionProvider, example::Example, format_prompt::TeacherPrompt};
2use anyhow::{Context as _, Result};
3use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
4
5pub fn run_parse_output(example: &mut Example) -> Result<()> {
6 let provider = example
7 .prompt
8 .as_ref()
9 .context("prompt required (run format-prompt first)")?
10 .provider;
11 example
12 .prompt_inputs
13 .as_ref()
14 .context("prompt_inputs required")?;
15
16 let parsed_patches: Vec<_> = example
17 .predictions
18 .iter()
19 .enumerate()
20 .filter(|(_, p)| !p.actual_output.is_empty())
21 .map(|(ix, prediction)| {
22 let actual_patch =
23 parse_prediction_output(example, &prediction.actual_output, provider);
24 actual_patch.map(|patch| (ix, patch))
25 })
26 .collect::<Result<Vec<_>>>()?;
27
28 for (ix, actual_patch) in parsed_patches {
29 example.predictions[ix].actual_patch = Some(actual_patch);
30 example.predictions[ix].provider = provider;
31 }
32
33 Ok(())
34}
35
36pub fn parse_prediction_output(
37 example: &Example,
38 actual_output: &str,
39 provider: PredictionProvider,
40) -> Result<String> {
41 match provider {
42 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
43 TeacherPrompt::parse(example, actual_output)
44 }
45 PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version),
46 _ => anyhow::bail!(
47 "parse-output only supports Teacher and Zeta2 providers, got {:?}",
48 provider
49 ),
50 }
51}
52
53fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<String> {
54 let (current_marker, end_marker) = match version {
55 ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
56 ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
57 ("<|fim_middle|>current\n", "<|fim_suffix|>")
58 }
59 ZetaVersion::V0120GitMergeMarkers => (
60 zeta_prompt::v0120_git_merge_markers::START_MARKER,
61 zeta_prompt::v0120_git_merge_markers::SEPARATOR,
62 ),
63 };
64
65 let start = prompt.find(current_marker).with_context(|| {
66 format!(
67 "missing current marker '{}' in prompt",
68 current_marker.trim()
69 )
70 })? + current_marker.len();
71
72 let end = prompt[start..]
73 .find(end_marker)
74 .with_context(|| format!("missing end marker '{}' in prompt", end_marker.trim()))?
75 + start;
76
77 let region = &prompt[start..end];
78 let region = region.strip_suffix('\n').unwrap_or(region);
79 Ok(region.replace(CURSOR_MARKER, ""))
80}
81
82fn parse_zeta2_output(
83 example: &Example,
84 actual_output: &str,
85 version: ZetaVersion,
86) -> Result<String> {
87 let prompt = &example.prompt.as_ref().context("prompt required")?.input;
88 let prompt_inputs = example
89 .prompt_inputs
90 .as_ref()
91 .context("prompt_inputs required")?;
92
93 let old_text = extract_zeta2_current_region(prompt, version)?;
94
95 let mut new_text = actual_output.replace(CURSOR_MARKER, "");
96
97 if version == ZetaVersion::V0120GitMergeMarkers {
98 if let Some(stripped) =
99 new_text.strip_suffix(zeta_prompt::v0120_git_merge_markers::END_MARKER)
100 {
101 new_text = stripped.to_string();
102 }
103 }
104
105 let mut old_text_normalized = old_text.clone();
106 if !new_text.is_empty() && !new_text.ends_with('\n') {
107 new_text.push('\n');
108 }
109 if !old_text_normalized.is_empty() && !old_text_normalized.ends_with('\n') {
110 old_text_normalized.push('\n');
111 }
112
113 let old_text_trimmed = old_text.trim_end_matches('\n');
114 let (editable_region_offset, _) = prompt_inputs
115 .content
116 .match_indices(old_text_trimmed)
117 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
118 .with_context(|| {
119 format!(
120 "could not find editable region in content.\nLooking for:\n{}\n\nIn content:\n{}",
121 old_text_trimmed, &prompt_inputs.content
122 )
123 })?;
124
125 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
126 .matches('\n')
127 .count();
128
129 let diff = language::unified_diff_with_offsets(
130 &old_text_normalized,
131 &new_text,
132 editable_region_start_line as u32,
133 editable_region_start_line as u32,
134 );
135
136 let formatted_diff = format!(
137 "--- a/{path}\n+++ b/{path}\n{diff}",
138 path = example.spec.cursor_path.to_string_lossy(),
139 );
140
141 Ok(formatted_diff)
142}
143
/// Unit tests for extracting the Zeta2 "current" editable region from rendered prompts.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_zeta2_current_region_v0113() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0114() {
        // V0114180EditableRegion shares the V0113Ordered marker layout.
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0114180EditableRegion).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0112() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|>current
            println!(\"hello\");
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0112MiddleAtEnd).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    #[test]
    fn test_extract_zeta2_current_region_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            print<|user_cursor|>ln!(\"hello\");
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    #[test]
    fn test_extract_zeta2_current_region_empty_region() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            <|fim_suffix|>
            }
            <|fim_middle|>updated
        "};

        let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
        assert_eq!(region, "");
    }

    #[test]
    fn test_extract_zeta2_current_region_missing_current_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
        "};

        let error = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap_err();
        assert!(error.to_string().contains("missing current marker"));
    }

    #[test]
    fn test_extract_zeta2_current_region_missing_end_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_middle|>current
            println!(\"hello\");
        "};

        let error = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap_err();
        assert!(error.to_string().contains("missing end marker"));
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_git_merge_markers() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            println!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }

    #[test]
    fn test_extract_zeta2_current_region_v0120_with_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|file_sep|>src/main.rs
            <|fim_prefix|>
            fn main() {
            <|fim_suffix|>
            }
            <|fim_middle|><<<<<<< CURRENT
            print<|user_cursor|>ln!(\"hello\");
            =======
        "};

        let region =
            extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
        assert_eq!(region, "println!(\"hello\");");
    }
}