1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use anyhow::{Context as _, Result, ensure};
10use edit_prediction::{
11 EditPredictionStore,
12 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
13};
14use gpui::AsyncApp;
15use std::sync::Arc;
16use zeta_prompt::format_zeta_prompt;
17
18pub async fn run_format_prompt(
19 example: &mut Example,
20 prompt_format: PromptFormat,
21 app_state: Arc<EpAppState>,
22 mut cx: AsyncApp,
23) -> Result<()> {
24 run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
25
26 let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
27
28 match prompt_format {
29 PromptFormat::Teacher => {
30 step_progress.set_substatus("formatting teacher prompt");
31 let prompt = TeacherPrompt::format_prompt(example);
32 example.prompt = Some(ExamplePrompt {
33 input: prompt,
34 expected_output: example
35 .spec
36 .expected_patches
37 .first()
38 .cloned()
39 .unwrap_or_default(),
40 format: prompt_format,
41 });
42 }
43 PromptFormat::Zeta2 => {
44 step_progress.set_substatus("loading project");
45 run_load_project(example, app_state, cx.clone()).await?;
46
47 step_progress.set_substatus("formatting zeta2 prompt");
48
49 let ep_store = cx.update(|cx| {
50 EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
51 })??;
52
53 let state = example.state.as_ref().context("state must be set")?;
54 let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
55 let project = state.project.clone();
56 let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
57 let events = ep_store
58 .edit_history_for_project(&project, cx)
59 .into_iter()
60 .map(|e| e.event)
61 .collect();
62 anyhow::Ok(zeta2_prompt_input(
63 &snapshot,
64 example
65 .context
66 .as_ref()
67 .context("context must be set")?
68 .files
69 .clone(),
70 events,
71 example.spec.cursor_path.clone(),
72 example
73 .buffer
74 .as_ref()
75 .context("buffer must be set")?
76 .cursor_offset,
77 ))
78 })??;
79 let prompt = format_zeta_prompt(&input);
80 let expected_output = zeta2_output_for_patch(
81 &input,
82 &example
83 .spec
84 .expected_patches
85 .first()
86 .context("expected patches is empty")?
87 .clone(),
88 )?;
89 example.prompt = Some(ExamplePrompt {
90 input: prompt,
91 expected_output,
92 format: prompt_format,
93 });
94 }
95 };
96 Ok(())
97}
98
99pub struct TeacherPrompt;
100
101impl TeacherPrompt {
102 const PROMPT: &str = include_str!("teacher.prompt.md");
103 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
104 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
105 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
106
107 /// Truncate edit history to this number of last lines
108 const MAX_HISTORY_LINES: usize = 128;
109
110 pub fn format_prompt(example: &Example) -> String {
111 let edit_history = Self::format_edit_history(&example.spec.edit_history);
112 let context = Self::format_context(example);
113 let editable_region = Self::format_editable_region(example);
114
115 let prompt = Self::PROMPT
116 .replace("{{context}}", &context)
117 .replace("{{edit_history}}", &edit_history)
118 .replace("{{editable_region}}", &editable_region);
119
120 prompt
121 }
122
123 pub fn parse(example: &Example, response: &str) -> Result<String> {
124 // Ideally, we should always be able to find cursor position in the retrieved context.
125 // In reality, sometimes we don't find it for these reasons:
126 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
127 // (can be fixed by getting cursor coordinates at the load_example stage)
128 // 2. Context retriever just didn't include cursor line.
129 //
130 // In that case, fallback to using `cursor_position` as excerpt.
131 let example_buffer = example
132 .buffer
133 .as_ref()
134 .context("`buffer` should be filled in in the context collection step")?;
135 let cursor_file = &example_buffer.content;
136
137 // Extract updated (new) editable region from the model response.
138 // The model may include editable region markers in its output, so we need to strip them.
139 let new_editable_region = extract_last_codeblock(response);
140 let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
141
142 let old_editable_region =
143 example_buffer.content[example_buffer.editable_range.clone()].to_string();
144
145 // Normalize leading newlines: if old starts with newline but new doesn't,
146 // prepend newline to new to preserve whitespace structure.
147 // This handles the case where the model drops the leading blank line.
148 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
149 new_editable_region.insert(0, '\n');
150 }
151
152 ensure!(
153 cursor_file.contains(&old_editable_region),
154 "Something's wrong: editable_region is not found in the cursor file"
155 );
156
157 // Apply editable region to a larger context and compute diff.
158 // This is needed to get a better context lines around the editable region
159 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
160 let diff = language::unified_diff(&cursor_file, &edited_file);
161
162 let diff = indoc::formatdoc! {"
163 --- a/{path}
164 +++ b/{path}
165 {diff}",
166 path = example.spec.cursor_path.to_string_lossy(),
167 diff = diff,
168 };
169
170 Ok(diff)
171 }
172
173 fn format_edit_history(edit_history: &str) -> String {
174 // Strip comments ("garbage lines") from edit history
175 let lines = edit_history
176 .lines()
177 .filter(|&s| Self::is_udiff_content_line(s))
178 .collect::<Vec<_>>();
179
180 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
181 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
182 } else {
183 &lines
184 };
185
186 if history_lines.is_empty() {
187 return "(No edit history)".to_string();
188 }
189
190 history_lines.join("\n")
191 }
192
193 fn format_context(example: &Example) -> String {
194 assert!(example.context.is_some(), "Missing context retriever step");
195
196 let mut prompt = String::new();
197 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
198
199 prompt
200 }
201
202 fn format_editable_region(example: &Example) -> String {
203 let mut result = String::new();
204
205 let example_buffer = example.buffer.as_ref().unwrap();
206
207 let path_str = example.spec.cursor_path.to_string_lossy();
208 result.push_str(&format!("`````path=\"{path_str}\"\n"));
209 result.push_str(
210 &example_buffer.content
211 [example_buffer.context_range.start..example_buffer.editable_range.start],
212 );
213 result.push_str(Self::EDITABLE_REGION_START);
214 result.push_str(
215 &example_buffer.content
216 [example_buffer.editable_range.start..example_buffer.cursor_offset],
217 );
218 result.push_str(Self::USER_CURSOR_MARKER);
219 result.push_str(
220 &example_buffer.content
221 [example_buffer.cursor_offset..example_buffer.editable_range.end],
222 );
223 result.push_str(Self::EDITABLE_REGION_END);
224 result.push_str(
225 &example_buffer.content
226 [example_buffer.editable_range.end..example_buffer.context_range.end],
227 );
228 result.push_str("\n`````");
229
230 result
231 }
232
233 fn extract_editable_region(text: &str) -> String {
234 let start = text
235 .find(Self::EDITABLE_REGION_START)
236 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
237 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
238
239 let region = &text[start..end];
240 let region = region.strip_suffix('\n').unwrap_or(region);
241
242 region.replace("<|user_cursor|>", "")
243 }
244
245 fn is_udiff_content_line(s: &str) -> bool {
246 s.starts_with("-")
247 || s.starts_with("+")
248 || s.starts_with(" ")
249 || s.starts_with("---")
250 || s.starts_with("+++")
251 || s.starts_with("@@")
252 }
253}
254
/// Returns the content of the last complete fenced code block in `text`.
///
/// The returned content excludes the fence lines themselves. If `text` contains no
/// complete block, the whole `text` is returned unchanged.
///
/// Fences follow a simplified CommonMark rule:
/// - An opening fence is a line that *starts* with three or more backticks. Anything may
///   follow on that line, e.g. an info string like `rust` or `path="..."`.
/// - The block is closed by the first later line consisting only of at least as many
///   backticks, optionally followed by whitespace.
///
/// Backtick runs that do not start a line, shorter fences inside a block, and fence-like
/// lines with trailing text (e.g. a nested "```python") are treated as block content.
/// An unclosed trailing block is ignored.
fn extract_last_codeblock(text: &str) -> String {
    let mut last_block: Option<&str> = None;
    // When inside a block: (opening fence length, byte offset where the content starts).
    let mut open_fence: Option<(usize, usize)> = None;
    let mut offset = 0;

    for line in text.split_inclusive('\n') {
        let line_start = offset;
        offset += line.len();

        // Backticks are ASCII, so this count is also a valid byte/char boundary.
        let fence_len = line.bytes().take_while(|&b| b == b'`').count();
        match open_fence {
            None => {
                if fence_len >= 3 {
                    open_fence = Some((fence_len, offset));
                }
            }
            Some((opening_len, content_start)) => {
                let is_closing_fence =
                    fence_len >= opening_len && line[fence_len..].trim().is_empty();
                if is_closing_fence {
                    last_block = Some(&text[content_start..line_start]);
                    open_fence = None;
                }
            }
        }
    }

    last_block.unwrap_or(text).to_string()
}
286
#[cfg(test)]
mod tests {
    use super::*;

    // Inputs are spelled out with `concat!` so each expected string is exact and does not
    // depend on how the test source happens to be indented.

    #[test]
    fn test_extract_last_code_block() {
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );
        let expected = "last block\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = concat!(
            "`````\n",
            "content with ``` inline\n",
            "and ```python nested\n",
            "more content\n",
            "`````\n",
        );
        let expected = concat!(
            "content with ``` inline\n",
            "and ```python nested\n",
            "more content\n",
        );
        assert_eq!(extract_last_codeblock(response), expected);
    }

    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = concat!(
            "`````\n",
            "here is some `code` with inline backticks\n",
            "and here```more```stuff\n",
            "`````\n",
        );
        let expected = concat!(
            "here is some `code` with inline backticks\n",
            "and here```more```stuff\n",
        );
        assert_eq!(extract_last_codeblock(response), expected);
    }

    #[test]
    fn test_extract_editable_region() {
        let response = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );
        let expected = "one\ntwo three\n";
        assert_eq!(TeacherPrompt::extract_editable_region(response), expected);
    }

    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let body = concat!(
            "## Collaborations\n",
            "Our mission is to create a 4D generative model.\n",
            "\n",
            "## Citation\n",
            "\n",
            "If you found Unique3D helpful, please cite our report:\n",
            "```bibtex\n",
            "@misc{wu2024unique3d,\n",
            "    title={Unique3D},\n",
            "}\n",
            "```\n",
        );
        let response = format!(
            "Looking at the edit history, I can see that a Citation section was just added.\n\n`````\n{body}`````\n"
        );
        assert_eq!(extract_last_codeblock(&response), body);
    }

    #[test]
    fn test_extract_editable_region_no_markers() {
        let response = "one\ntwo three\n";
        let expected = "one\ntwo three";
        assert_eq!(TeacherPrompt::extract_editable_region(response), expected);
    }

    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let response = concat!(
            "<|editable_region_start|>\n",
            "one\n",
            "<|user_cursor|>two three\n",
            "\n",
            "<|editable_region_end|>\n",
        );
        let expected = "one\ntwo three\n";
        assert_eq!(TeacherPrompt::extract_editable_region(response), expected);
    }
}