1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use anyhow::{Context as _, Result};
10use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
11use gpui::{AsyncApp, Entity};
12use similar::DiffableStr;
13use std::fmt::Write as _;
14use std::sync::Arc;
15use zeta_prompt::format_zeta_prompt;
16
/// Builds the model prompt for `example` in the requested `prompt_format` and stores it in
/// `example.prompt`.
///
/// Context retrieval ([`run_context_retrieval`]) always runs first. Then:
///
/// - [`PromptFormat::Teacher`]: the prompt is rendered with [`TeacherPrompt::format_prompt`],
///   and the expected output is the first entry of `example.spec.expected_patches`, or an
///   empty string when there is none.
/// - [`PromptFormat::Zeta2`]: the project is loaded with [`run_load_project`], a Zeta2 prompt
///   input is assembled from the cursor buffer snapshot, the retrieved context files, the
///   project's edit history (taken from the global [`EditPredictionStore`]), the cursor path
///   and the cursor offset, and it is rendered with `format_zeta_prompt`. The expected output
///   is produced by [`zeta2_output_for_patch`] from the first expected patch.
///
/// # Errors
///
/// Returns an error if context retrieval or project loading fails. For Zeta2 it also errors
/// when the `EditPredictionStore` global is not initialized, when `example.state`,
/// `example.context` or `example.buffer` is unset, when there are no expected patches, or when
/// the first expected patch cannot be turned into an expected output.
///
/// # Panics
///
/// The Teacher path may panic inside [`TeacherPrompt::format_prompt`] if `example.context` or
/// `example.buffer` is unset.
pub async fn run_format_prompt(
    example: &mut Example,
    prompt_format: PromptFormat,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> Result<()> {
    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

    // Progress reporting for this example's FormatPrompt step; substatuses are set per branch.
    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);

    match prompt_format {
        PromptFormat::Teacher => {
            step_progress.set_substatus("formatting teacher prompt");
            let prompt = TeacherPrompt::format_prompt(example);
            example.prompt = Some(ExamplePrompt {
                input: prompt,
                // Unlike the Zeta2 branch, a missing expected patch is not an error here:
                // it falls back to an empty expected output.
                expected_output: example
                    .spec
                    .expected_patches
                    .first()
                    .cloned()
                    .unwrap_or_default(),
                format: prompt_format,
            });
        }
        PromptFormat::Zeta2 => {
            step_progress.set_substatus("loading project");
            run_load_project(example, app_state, cx.clone()).await?;

            step_progress.set_substatus("formatting zeta2 prompt");

            let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
                EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
            })?;

            let state = example.state.as_ref().context("state must be set")?;
            let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
            let project = state.project.clone();
            // Only the second element returned by `zeta2_prompt_input` is used here.
            let (_, input) =
                ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
                    // The project's recorded edit history, reduced to its events.
                    let events = ep_store
                        .edit_history_for_project(&project, cx)
                        .into_iter()
                        .map(|e| e.event)
                        .collect();
                    anyhow::Ok(zeta2_prompt_input(
                        &snapshot,
                        example
                            .context
                            .as_ref()
                            .context("context must be set")?
                            .files
                            .clone(),
                        events,
                        example.spec.cursor_path.clone(),
                        example
                            .buffer
                            .as_ref()
                            .context("buffer must be set")?
                            .cursor_offset,
                    ))
                })?;
            let prompt = format_zeta_prompt(&input);
            // Zeta2 requires at least one expected patch; the first one defines the target.
            let expected_output = zeta2_output_for_patch(
                &input,
                &example
                    .spec
                    .expected_patches
                    .first()
                    .context("expected patches is empty")?
                    .clone(),
            )?;
            example.prompt = Some(ExamplePrompt {
                input: prompt,
                expected_output,
                format: prompt_format,
            });
        }
    };
    Ok(())
}
98
99pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
100 let mut old_editable_region =
101 input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
102
103 if !old_editable_region.ends_with_newline() {
104 old_editable_region.push('\n');
105 }
106
107 edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
108 format!(
109 "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
110 patch, old_editable_region
111 )
112 })
113}
114
115pub struct TeacherPrompt;
116
117impl TeacherPrompt {
118 const PROMPT: &str = include_str!("teacher.prompt.md");
119 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
120 pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
121 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
122
123 /// Truncate edit history to this number of last lines
124 const MAX_HISTORY_LINES: usize = 128;
125
126 pub fn format_prompt(example: &Example) -> String {
127 let edit_history = Self::format_edit_history(&example.spec.edit_history);
128 let context = Self::format_context(example);
129 let cursor_excerpt = Self::format_cursor_excerpt(example);
130
131 let prompt = Self::PROMPT
132 .replace("{{context}}", &context)
133 .replace("{{edit_history}}", &edit_history)
134 .replace("{{cursor_excerpt}}", &cursor_excerpt);
135
136 prompt
137 }
138
139 pub fn parse(example: &Example, response: &str) -> Result<String> {
140 // Ideally, we should always be able to find cursor position in the retrieved context.
141 // In reality, sometimes we don't find it for these reasons:
142 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
143 // (can be fixed by getting cursor coordinates at the load_example stage)
144 // 2. Context retriever just didn't include cursor line.
145 //
146 // In that case, fallback to using `cursor_position` as excerpt.
147 let example_buffer = example
148 .buffer
149 .as_ref()
150 .context("`buffer` should be filled in in the context collection step")?;
151
152 // Extract updated (new) editable region from the model response.
153 // The model may include editable region markers in its output, so we need to strip them.
154 let new_editable_region = extract_last_codeblock(response);
155 let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
156
157 let old_editable_region =
158 example_buffer.content[example_buffer.editable_range.clone()].to_string();
159
160 // Normalize leading newlines: if old starts with newline but new doesn't,
161 // prepend newline to new to preserve whitespace structure.
162 // This handles the case where the model drops the leading blank line.
163 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
164 new_editable_region.insert(0, '\n');
165 }
166
167 let editable_region_start_line = example_buffer.content
168 [..example_buffer.editable_range.start]
169 .matches('\n')
170 .count();
171
172 let diff = language::unified_diff_with_offsets(
173 &old_editable_region,
174 &new_editable_region,
175 editable_region_start_line as u32,
176 editable_region_start_line as u32,
177 );
178
179 let diff = indoc::formatdoc! {"
180 --- a/{path}
181 +++ b/{path}
182 {diff}",
183 path = example.spec.cursor_path.to_string_lossy(),
184 diff = diff,
185 };
186
187 Ok(diff)
188 }
189
190 fn format_edit_history(edit_history: &str) -> String {
191 // Strip comments ("garbage lines") from edit history
192 let lines = edit_history
193 .lines()
194 .filter(|&s| Self::is_udiff_content_line(s))
195 .collect::<Vec<_>>();
196
197 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
198 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
199 } else {
200 &lines
201 };
202
203 if history_lines.is_empty() {
204 return "(No edit history)".to_string();
205 }
206
207 history_lines.join("\n")
208 }
209
210 fn format_context(example: &Example) -> String {
211 let context = example
212 .context
213 .as_ref()
214 .expect("Missing context retriever step");
215
216 if context.files.is_empty() {
217 return "(No context)".to_string();
218 }
219
220 let mut prompt = String::new();
221 for file in context.files.as_ref() {
222 let path_str = file.path.to_string_lossy();
223 writeln!(&mut prompt, "`````{path_str}").ok();
224 let mut prev_row = 0;
225 for excerpt in &file.excerpts {
226 if excerpt.row_range.start > prev_row {
227 prompt.push_str("…\n");
228 }
229 prompt.push_str(&excerpt.text);
230 prompt.push('\n');
231 prev_row = excerpt.row_range.end;
232 }
233 if prev_row < file.max_row {
234 prompt.push_str("…\n");
235 }
236 prompt.push_str("\n`````");
237 }
238
239 prompt
240 }
241
242 fn format_cursor_excerpt(example: &Example) -> String {
243 let mut result = String::new();
244
245 let example_buffer = example.buffer.as_ref().unwrap();
246
247 let path_str = example.spec.cursor_path.to_string_lossy();
248 result.push_str(&format!("`````{path_str}\n"));
249 result.push_str(
250 &example_buffer.content
251 [example_buffer.context_range.start..example_buffer.editable_range.start],
252 );
253 result.push_str(Self::EDITABLE_REGION_START);
254 result.push_str(
255 &example_buffer.content
256 [example_buffer.editable_range.start..example_buffer.cursor_offset],
257 );
258 result.push_str(Self::USER_CURSOR_MARKER);
259 result.push_str(
260 &example_buffer.content
261 [example_buffer.cursor_offset..example_buffer.editable_range.end],
262 );
263 result.push_str(Self::EDITABLE_REGION_END);
264 result.push_str(
265 &example_buffer.content
266 [example_buffer.editable_range.end..example_buffer.context_range.end],
267 );
268 result.push_str("\n`````");
269
270 result
271 }
272
273 fn extract_editable_region(text: &str) -> String {
274 let start = text
275 .find(Self::EDITABLE_REGION_START)
276 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
277 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
278
279 let region = &text[start..end];
280 let region = region.strip_suffix('\n').unwrap_or(region);
281
282 region.replace(Self::USER_CURSOR_MARKER, "")
283 }
284
285 fn is_udiff_content_line(s: &str) -> bool {
286 s.starts_with("-")
287 || s.starts_with("+")
288 || s.starts_with(" ")
289 || s.starts_with("---")
290 || s.starts_with("+++")
291 || s.starts_with("@@")
292 }
293}
294
/// Returns the body of the last complete fenced code block in `text`.
///
/// A fence is a run of three or more backticks. Everything after the run up to the end of
/// its line (e.g. a language tag or a path) is the fence header. The block closes at the first
/// subsequent line beginning with at least as many backticks, so a longer fence can wrap
/// content that contains shorter fences. The returned body keeps its trailing newline.
/// When no complete block exists, the whole `text` is returned unchanged.
fn extract_last_codeblock(text: &str) -> String {
    let mut last: Option<&str> = None;
    let mut cursor = 0;

    while let Some(found_at) = text[cursor..].find("```") {
        let fence_start = cursor + found_at;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // The header runs up to the next newline, or to the end of `text`.
        let after_fence = fence_start + fence_len;
        let header_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |offset| after_fence + offset);

        match text[header_end..].find(&closing_fence) {
            Some(body_len) => {
                last = Some(&text[header_end + 1..header_end + body_len + 1]);
                cursor = header_end + body_len + closing_fence.len();
            }
            None => break,
        }
    }

    last.map_or_else(|| text.to_string(), str::to_string)
}
326
#[cfg(test)]
mod tests {
    // Unit tests for `extract_last_codeblock` and `TeacherPrompt::extract_editable_region`.
    use super::*;

    // With several fenced blocks present, only the last one's body is returned.
    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    // Shorter backtick runs inside a five-backtick fence do not close the block.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    // Backticks that do not start a line inside the block are kept as content.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    // Text outside the markers is dropped, and one trailing newline of the region is removed.
    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }

    // A three-backtick bibtex block nested in a five-backtick fence is preserved verbatim.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
            `````
            "#};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
                ## Collaborations
                Our mission is to create a 4D generative model.

                ## Citation

                If you found Unique3D helpful, please cite our report:
                ```bibtex
                @misc{wu2024unique3d,
                title={Unique3D},
                }
                ```
                "#}
        );
    }

    // Without markers, the whole text is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }

    // The `<|user_cursor|>` marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let text = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }
}