1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use anyhow::{Context as _, Result, ensure};
10use edit_prediction::{
11 EditPredictionStore,
12 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
13};
14use gpui::{AsyncApp, Entity};
15use std::sync::Arc;
16use zeta_prompt::format_zeta_prompt;
17
18pub async fn run_format_prompt(
19 example: &mut Example,
20 prompt_format: PromptFormat,
21 app_state: Arc<EpAppState>,
22 mut cx: AsyncApp,
23) -> Result<()> {
24 run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
25
26 let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
27
28 match prompt_format {
29 PromptFormat::Teacher => {
30 step_progress.set_substatus("formatting teacher prompt");
31 let prompt = TeacherPrompt::format_prompt(example);
32 example.prompt = Some(ExamplePrompt {
33 input: prompt,
34 expected_output: example
35 .spec
36 .expected_patches
37 .first()
38 .cloned()
39 .unwrap_or_default(),
40 format: prompt_format,
41 });
42 }
43 PromptFormat::Zeta2 => {
44 step_progress.set_substatus("loading project");
45 run_load_project(example, app_state, cx.clone()).await?;
46
47 step_progress.set_substatus("formatting zeta2 prompt");
48
49 let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
50 EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
51 })?;
52
53 let state = example.state.as_ref().context("state must be set")?;
54 let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
55 let project = state.project.clone();
56 let (_, input) =
57 ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
58 let events = ep_store
59 .edit_history_for_project(&project, cx)
60 .into_iter()
61 .map(|e| e.event)
62 .collect();
63 anyhow::Ok(zeta2_prompt_input(
64 &snapshot,
65 example
66 .context
67 .as_ref()
68 .context("context must be set")?
69 .files
70 .clone(),
71 events,
72 example.spec.cursor_path.clone(),
73 example
74 .buffer
75 .as_ref()
76 .context("buffer must be set")?
77 .cursor_offset,
78 ))
79 })?;
80 let prompt = format_zeta_prompt(&input);
81 let expected_output = zeta2_output_for_patch(
82 &input,
83 &example
84 .spec
85 .expected_patches
86 .first()
87 .context("expected patches is empty")?
88 .clone(),
89 )?;
90 example.prompt = Some(ExamplePrompt {
91 input: prompt,
92 expected_output,
93 format: prompt_format,
94 });
95 }
96 };
97 Ok(())
98}
99
100pub struct TeacherPrompt;
101
102impl TeacherPrompt {
103 const PROMPT: &str = include_str!("teacher.prompt.md");
104 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
105 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
106 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
107
108 /// Truncate edit history to this number of last lines
109 const MAX_HISTORY_LINES: usize = 128;
110
111 pub fn format_prompt(example: &Example) -> String {
112 let edit_history = Self::format_edit_history(&example.spec.edit_history);
113 let context = Self::format_context(example);
114 let editable_region = Self::format_editable_region(example);
115
116 let prompt = Self::PROMPT
117 .replace("{{context}}", &context)
118 .replace("{{edit_history}}", &edit_history)
119 .replace("{{editable_region}}", &editable_region);
120
121 prompt
122 }
123
124 pub fn parse(example: &Example, response: &str) -> Result<String> {
125 // Ideally, we should always be able to find cursor position in the retrieved context.
126 // In reality, sometimes we don't find it for these reasons:
127 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
128 // (can be fixed by getting cursor coordinates at the load_example stage)
129 // 2. Context retriever just didn't include cursor line.
130 //
131 // In that case, fallback to using `cursor_position` as excerpt.
132 let example_buffer = example
133 .buffer
134 .as_ref()
135 .context("`buffer` should be filled in in the context collection step")?;
136 let cursor_file = &example_buffer.content;
137
138 // Extract updated (new) editable region from the model response.
139 // The model may include editable region markers in its output, so we need to strip them.
140 let new_editable_region = extract_last_codeblock(response);
141 let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
142
143 let old_editable_region =
144 example_buffer.content[example_buffer.editable_range.clone()].to_string();
145
146 // Normalize leading newlines: if old starts with newline but new doesn't,
147 // prepend newline to new to preserve whitespace structure.
148 // This handles the case where the model drops the leading blank line.
149 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
150 new_editable_region.insert(0, '\n');
151 }
152
153 ensure!(
154 cursor_file.contains(&old_editable_region),
155 "Something's wrong: editable_region is not found in the cursor file"
156 );
157
158 // Apply editable region to a larger context and compute diff.
159 // This is needed to get a better context lines around the editable region
160 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
161 let diff = language::unified_diff(&cursor_file, &edited_file);
162
163 let diff = indoc::formatdoc! {"
164 --- a/{path}
165 +++ b/{path}
166 {diff}",
167 path = example.spec.cursor_path.to_string_lossy(),
168 diff = diff,
169 };
170
171 Ok(diff)
172 }
173
174 fn format_edit_history(edit_history: &str) -> String {
175 // Strip comments ("garbage lines") from edit history
176 let lines = edit_history
177 .lines()
178 .filter(|&s| Self::is_udiff_content_line(s))
179 .collect::<Vec<_>>();
180
181 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
182 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
183 } else {
184 &lines
185 };
186
187 if history_lines.is_empty() {
188 return "(No edit history)".to_string();
189 }
190
191 history_lines.join("\n")
192 }
193
194 fn format_context(example: &Example) -> String {
195 assert!(example.context.is_some(), "Missing context retriever step");
196
197 let mut prompt = String::new();
198 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
199
200 prompt
201 }
202
203 fn format_editable_region(example: &Example) -> String {
204 let mut result = String::new();
205
206 let example_buffer = example.buffer.as_ref().unwrap();
207
208 let path_str = example.spec.cursor_path.to_string_lossy();
209 result.push_str(&format!("`````path=\"{path_str}\"\n"));
210 result.push_str(
211 &example_buffer.content
212 [example_buffer.context_range.start..example_buffer.editable_range.start],
213 );
214 result.push_str(Self::EDITABLE_REGION_START);
215 result.push_str(
216 &example_buffer.content
217 [example_buffer.editable_range.start..example_buffer.cursor_offset],
218 );
219 result.push_str(Self::USER_CURSOR_MARKER);
220 result.push_str(
221 &example_buffer.content
222 [example_buffer.cursor_offset..example_buffer.editable_range.end],
223 );
224 result.push_str(Self::EDITABLE_REGION_END);
225 result.push_str(
226 &example_buffer.content
227 [example_buffer.editable_range.end..example_buffer.context_range.end],
228 );
229 result.push_str("\n`````");
230
231 result
232 }
233
234 fn extract_editable_region(text: &str) -> String {
235 let start = text
236 .find(Self::EDITABLE_REGION_START)
237 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
238 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
239
240 let region = &text[start..end];
241 let region = region.strip_suffix('\n').unwrap_or(region);
242
243 region.replace("<|user_cursor|>", "")
244 }
245
246 fn is_udiff_content_line(s: &str) -> bool {
247 s.starts_with("-")
248 || s.starts_with("+")
249 || s.starts_with(" ")
250 || s.starts_with("---")
251 || s.starts_with("+++")
252 || s.starts_with("@@")
253 }
254}
255
/// Returns the contents of the last complete fenced code block in `text`.
///
/// A fence opens at any run of three or more backticks; the remainder of that line is
/// skipped. The block closes at the first following line that starts with at least as many
/// backticks. The returned contents include the trailing newline before the closing fence.
/// If no complete block is found, the whole `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    const MIN_FENCE: &str = "```";

    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    while let Some(offset) = text[cursor..].find(MIN_FENCE) {
        let fence_start = cursor + offset;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();

        // Skip the rest of the opening line (e.g. a language tag or attributes).
        let info_start = fence_start + fence_len;
        let line_end = text[info_start..]
            .find('\n')
            .map_or(text.len(), |newline| info_start + newline);

        let closing_fence = format!("\n{}", "`".repeat(fence_len));
        let Some(close_offset) = text[line_end..].find(&closing_fence) else {
            // An unterminated fence ends the scan; earlier blocks are still honoured.
            break;
        };

        // Contents start after the newline ending the opening line and run through the
        // newline that precedes the closing fence.
        latest = Some(&text[line_end + 1..line_end + close_offset + 1]);
        cursor = line_end + close_offset + closing_fence.len();
    }

    latest.map_or_else(|| text.to_string(), str::to_string)
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks, only the contents of the final one are returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Shorter fences inside a five-backtick block are part of its contents.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = concat!(
            "`````\n",
            "content with ``` inline\n",
            "and ```python nested\n",
            "more content\n",
            "`````\n",
        );

        assert_eq!(
            extract_last_codeblock(response),
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    /// Inline backticks inside the block do not terminate it.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = concat!(
            "`````\n",
            "here is some `code` with inline backticks\n",
            "and here```more```stuff\n",
            "`````\n",
        );

        assert_eq!(
            extract_last_codeblock(response),
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    /// Text outside the markers is dropped, along with one trailing newline of the region.
    #[test]
    fn test_extract_editable_region() {
        let response = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );

        let region = TeacherPrompt::extract_editable_region(response);

        assert_eq!(region, "one\ntwo three\n");
    }

    /// A nested three-backtick block does not cut the outer five-backtick block short.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let block_contents = concat!(
            "## Collaborations\n",
            "Our mission is to create a 4D generative model.\n",
            "\n",
            "## Citation\n",
            "\n",
            "If you found Unique3D helpful, please cite our report:\n",
            "```bibtex\n",
            "@misc{wu2024unique3d,\n",
            "title={Unique3D},\n",
            "}\n",
            "```\n",
        );
        let response = format!(
            "Looking at the edit history, I can see that a Citation section was just added.\n\n`````\n{block_contents}`````\n"
        );

        assert_eq!(extract_last_codeblock(&response), block_contents);
    }

    /// Without markers the whole text is the region, minus one trailing newline.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let region = TeacherPrompt::extract_editable_region("one\ntwo three\n");

        assert_eq!(region, "one\ntwo three");
    }

    /// The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let response = concat!(
            "<|editable_region_start|>\n",
            "one\n",
            "<|user_cursor|>two three\n",
            "\n",
            "<|editable_region_end|>\n",
        );

        let region = TeacherPrompt::extract_editable_region(response);

        assert_eq!(region, "one\ntwo three\n");
    }
}