1use crate::{
2 FormatPromptArgs, PredictionProvider,
3 example::{ActualCursor, Example, ExamplePrompt},
4 headless::EpAppState,
5 progress::{ExampleProgress, Step},
6 retrieve_context::run_context_retrieval,
7};
8use anyhow::{Context as _, Result, anyhow};
9use edit_prediction::{cursor_excerpt::editable_and_context_ranges_for_cursor_position, udiff};
10use gpui::{AppContext, AsyncApp};
11use language::{Buffer, OffsetRangeExt, Point};
12use similar::DiffableStr;
13use std::sync::Arc;
14use std::{fmt::Write as _, ops::Range};
15use zeta_prompt::ZetaFormat;
16use zeta_prompt::format_zeta_prompt;
17
18pub async fn run_format_prompt(
19 example: &mut Example,
20 args: &FormatPromptArgs,
21 app_state: Arc<EpAppState>,
22 example_progress: &ExampleProgress,
23 cx: AsyncApp,
24) -> Result<()> {
25 run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
26
27 let step_progress = example_progress.start(Step::FormatPrompt);
28
29 let prompt_inputs = example
30 .prompt_inputs
31 .as_ref()
32 .context("prompt_inputs must be set after context retrieval")?;
33
34 let language = app_state
35 .languages
36 .load_language_for_file_path(&example.spec.cursor_path)
37 .await
38 .ok();
39 let snapshot_fut = cx.update(|cx| {
40 Buffer::build_snapshot(
41 prompt_inputs.content.as_str().into(),
42 language,
43 Some(app_state.languages.clone()),
44 cx,
45 )
46 });
47 let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
48 let snapshot = cx.background_spawn(snapshot_fut).await;
49
50 match args.provider {
51 PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
52 step_progress.set_substatus("formatting teacher prompt");
53
54 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
55 cursor_point,
56 &snapshot,
57 edit_prediction::zeta2::max_editable_tokens(ZetaFormat::default()),
58 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
59 );
60 let editable_range = editable_range.to_offset(&snapshot);
61 let context_range = context_range.to_offset(&snapshot);
62
63 let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
64 example.prompt = Some(ExamplePrompt {
65 input: prompt,
66 expected_output: String::new(),
67 rejected_output: None,
68 prefill: None,
69 provider: args.provider,
70 });
71 }
72 PredictionProvider::Zeta2(version) => {
73 step_progress.set_substatus("formatting zeta2 prompt");
74
75 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
76 cursor_point,
77 &snapshot,
78 edit_prediction::zeta2::max_editable_tokens(version),
79 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
80 );
81 let editable_range = editable_range.to_offset(&snapshot);
82 let context_range = context_range.to_offset(&snapshot);
83
84 let context_start = context_range.start;
85 let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
86 let editable_range_in_excerpt =
87 (editable_range.start - context_start)..(editable_range.end - context_start);
88 let input = zeta_prompt::ZetaPromptInput {
89 cursor_path: example.spec.cursor_path.clone(),
90 cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
91 editable_range_in_excerpt,
92 cursor_offset_in_excerpt,
93 excerpt_start_row: prompt_inputs.excerpt_start_row,
94 events: prompt_inputs.edit_history.clone(),
95 related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
96 };
97 let prompt = format_zeta_prompt(&input, version);
98 let prefill = zeta_prompt::get_prefill(&input, version);
99 let (expected_patch, expected_cursor_offset) = example
100 .spec
101 .expected_patches_with_cursor_positions()
102 .into_iter()
103 .next()
104 .context("expected patches is empty")?;
105 let expected_output =
106 zeta2_output_for_patch(&input, &expected_patch, expected_cursor_offset, version)?;
107 let rejected_output = example
108 .spec
109 .rejected_patch
110 .as_ref()
111 .and_then(|patch| zeta2_output_for_patch(&input, patch, None, version).ok());
112
113 example.prompt = Some(ExamplePrompt {
114 input: prompt,
115 expected_output,
116 rejected_output,
117 provider: args.provider,
118 prefill: Some(prefill),
119 });
120 }
121 _ => {
122 panic!("Cannot format prompt for {:?}", args.provider);
123 }
124 };
125 Ok(())
126}
127
128pub fn zeta2_output_for_patch(
129 input: &zeta_prompt::ZetaPromptInput,
130 patch: &str,
131 cursor_offset: Option<usize>,
132 version: ZetaFormat,
133) -> Result<String> {
134 let mut old_editable_region =
135 input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
136
137 if !old_editable_region.ends_with_newline() {
138 old_editable_region.push('\n');
139 }
140
141 let (mut result, first_hunk_offset) =
142 udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context(
143 || {
144 format!(
145 "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
146 patch, old_editable_region
147 )
148 },
149 )?;
150
151 if let Some(cursor_offset) = cursor_offset {
152 // The cursor_offset is relative to the start of the hunk's new text (context + additions).
153 // We need to add where the hunk context matched in the editable region to compute
154 // the actual cursor position in the result.
155 let hunk_start = first_hunk_offset.unwrap_or(0);
156 let offset = (hunk_start + cursor_offset).min(result.len());
157 result.insert_str(offset, zeta_prompt::CURSOR_MARKER);
158 }
159
160 match version {
161 ZetaFormat::V0120GitMergeMarkers
162 | ZetaFormat::V0131GitMergeMarkersPrefix
163 | ZetaFormat::V0211SeedCoder => {
164 if !result.ends_with('\n') {
165 result.push('\n');
166 }
167 result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
168 }
169 _ => (),
170 }
171
172 Ok(result)
173}
174
175pub struct TeacherPrompt;
176
177impl TeacherPrompt {
178 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
179 pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
180 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
181 pub(crate) const NO_EDITS: &str = "NO_EDITS";
182
183 /// Truncate edit history to this number of last lines
184 const MAX_HISTORY_LINES: usize = 128;
185
186 pub fn format_prompt(
187 example: &Example,
188 editable_range: Range<usize>,
189 context_range: Range<usize>,
190 ) -> String {
191 let edit_history = Self::format_edit_history(&example.spec.edit_history);
192 let context = Self::format_context(example);
193 let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
194
195 let prompt_template = crate::prompt_assets::get_prompt("teacher.md");
196 let prompt = prompt_template
197 .replace("{{context}}", &context)
198 .replace("{{edit_history}}", &edit_history)
199 .replace("{{cursor_excerpt}}", &cursor_excerpt);
200
201 prompt
202 }
203
204 pub fn parse(example: &Example, response: &str) -> Result<(String, Option<ActualCursor>)> {
205 // Extract updated (new) editable region from the model response.
206 // The model may include editable region markers in its output, so we need to strip them.
207 let new_editable_region = extract_last_codeblock(response);
208
209 // Check if the model indicated no edits are needed
210 if new_editable_region.trim() == Self::NO_EDITS {
211 return Ok((String::new(), None));
212 }
213
214 let new_editable_region = Self::extract_editable_region(&new_editable_region)?;
215 let cursor_offset = new_editable_region.find(Self::USER_CURSOR_MARKER);
216 let mut new_editable_region = new_editable_region.replace(Self::USER_CURSOR_MARKER, "");
217 let old_editable_region = Self::extract_editable_region(
218 &example
219 .prompt
220 .as_ref()
221 .context("example prompt missing")?
222 .input,
223 )?
224 .replace(Self::USER_CURSOR_MARKER, "");
225
226 let prompt_inputs = example
227 .prompt_inputs
228 .as_ref()
229 .context("example is missing prompt inputs")?;
230
231 // Normalize leading newlines: if old starts with newline but new doesn't,
232 // prepend newline to new to preserve whitespace structure.
233 // This handles the case where the model drops the leading blank line.
234 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
235 new_editable_region.insert(0, '\n');
236 }
237
238 let (editable_region_offset, _) = prompt_inputs
239 .content
240 .match_indices(&old_editable_region)
241 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
242 .context("editable region not found in prompt content")?;
243 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
244 .matches('\n')
245 .count();
246
247 // Use full context so cursor offset (relative to editable region start) aligns with diff content
248 let editable_region_lines = old_editable_region.lines().count() as u32;
249 let diff = language::unified_diff_with_context(
250 &old_editable_region,
251 &new_editable_region,
252 editable_region_start_line as u32,
253 editable_region_start_line as u32,
254 editable_region_lines,
255 );
256
257 let diff = indoc::formatdoc! {"
258 --- a/{path}
259 +++ b/{path}
260 {diff}",
261 path = example.spec.cursor_path.to_string_lossy(),
262 diff = diff,
263 };
264
265 let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
266 ActualCursor::from_editable_region(
267 &example.spec.cursor_path,
268 editable_region_cursor_offset,
269 &new_editable_region,
270 &prompt_inputs.content,
271 editable_region_offset,
272 editable_region_start_line,
273 )
274 });
275
276 Ok((diff, actual_cursor))
277 }
278
279 fn format_edit_history(edit_history: &str) -> String {
280 // Strip comments ("garbage lines") from edit history
281 let lines = edit_history
282 .lines()
283 .filter(|&s| Self::is_udiff_content_line(s))
284 .collect::<Vec<_>>();
285
286 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
287 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
288 } else {
289 &lines
290 };
291
292 if history_lines.is_empty() {
293 return "(No edit history)".to_string();
294 }
295
296 history_lines.join("\n")
297 }
298
299 pub fn format_context(example: &Example) -> String {
300 let related_files = example
301 .prompt_inputs
302 .as_ref()
303 .and_then(|pi| pi.related_files.as_ref());
304
305 let Some(related_files) = related_files else {
306 return "(No context)".to_string();
307 };
308
309 if related_files.is_empty() {
310 return "(No context)".to_string();
311 }
312
313 let mut prompt = String::new();
314 for file in related_files {
315 let path_str = file.path.to_string_lossy();
316 writeln!(&mut prompt, "`````{path_str}").ok();
317
318 let mut prev_row = 0;
319 for excerpt in &file.excerpts {
320 if excerpt.row_range.start > prev_row {
321 prompt.push_str("…\n");
322 }
323 prompt.push_str(&excerpt.text);
324 prompt.push('\n');
325 prev_row = excerpt.row_range.end;
326 }
327 if prev_row < file.max_row {
328 prompt.push_str("…\n");
329 }
330 prompt.push_str("\n`````\n");
331 }
332
333 prompt
334 }
335
336 fn format_cursor_excerpt(
337 example: &Example,
338 editable_range: Range<usize>,
339 context_range: Range<usize>,
340 ) -> String {
341 let mut result = String::new();
342
343 let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
344
345 let path_str = example.spec.cursor_path.to_string_lossy();
346 result.push_str(&format!("`````{path_str}\n"));
347 result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
348 result.push_str(Self::EDITABLE_REGION_START);
349 result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
350 result.push_str(Self::USER_CURSOR_MARKER);
351 result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
352 result.push_str(Self::EDITABLE_REGION_END);
353 result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
354 result.push_str("\n`````");
355
356 result
357 }
358
359 pub fn extract_editable_region(text: &str) -> Result<String> {
360 let start = text
361 .rfind(Self::EDITABLE_REGION_START)
362 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
363 let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
364
365 if start >= end {
366 return Err(anyhow!("Invalid editable region markers"));
367 }
368
369 let region = &text[start..end];
370 Ok(region.strip_suffix('\n').unwrap_or(region).to_string())
371 }
372
373 fn is_udiff_content_line(s: &str) -> bool {
374 s.starts_with("-")
375 || s.starts_with("+")
376 || s.starts_with(" ")
377 || s.starts_with("---")
378 || s.starts_with("+++")
379 || s.starts_with("@@")
380 }
381}
382
383/// Extract the cursor excerpt from an example.
384/// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs.
385pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option<String> {
386 // If we have the original prompt, extract the cursor excerpt from it
387 if let Some(prompt) = &example.prompt {
388 // Find "# 3. Current File" section and extract the content
389 if let Some(start) = prompt.input.find("# 3. Current File") {
390 let content_start = prompt.input[start..].find('`').map(|i| start + i)?;
391 let backtick_count = prompt.input[content_start..]
392 .chars()
393 .take_while(|&c| c == '`')
394 .count();
395 let content_start = content_start + backtick_count;
396
397 // Find the path line and skip it
398 let newline_pos = prompt.input[content_start..].find('\n')?;
399 let text_start = content_start + newline_pos + 1;
400
401 // Find the closing backticks
402 let closing_pattern = "`".repeat(backtick_count);
403 let text_end = prompt.input[text_start..].find(&closing_pattern)?;
404 let cursor_excerpt = &prompt.input[text_start..text_start + text_end];
405
406 let path_str = example.spec.cursor_path.to_string_lossy();
407 return Some(format!("`````{path_str}\n{cursor_excerpt}`````"));
408 }
409 }
410
411 // Fallback: construct from prompt_inputs if available
412 let prompt_inputs = example.prompt_inputs.as_ref()?;
413 let content = &prompt_inputs.content;
414 let cursor_offset = prompt_inputs.cursor_offset;
415
416 // Simple fallback: just show content around cursor with markers
417 let path_str = example.spec.cursor_path.to_string_lossy();
418 let mut result = format!("`````{path_str}\n");
419 result.push_str(TeacherPrompt::EDITABLE_REGION_START);
420 result.push_str(&content[..cursor_offset]);
421 result.push_str(TeacherPrompt::USER_CURSOR_MARKER);
422 result.push_str(&content[cursor_offset..]);
423 result.push_str(TeacherPrompt::EDITABLE_REGION_END);
424 result.push_str("\n`````");
425
426 Some(result)
427}
428
/// Returns the contents of the last complete fenced code block in `text`.
///
/// A fence is any run of three or more backticks, not necessarily at the start of a line. The
/// rest of the opening line (e.g. a language tag or path) is skipped. The block ends at the
/// first following line that begins with at least as many backticks as the opening fence, so
/// shorter fences nested inside a longer one are kept as content. Scanning resumes after each
/// closed block.
///
/// The returned text keeps its trailing newline. If `text` contains no complete fenced block,
/// the whole input is returned unchanged.
pub(crate) fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last: Option<&str> = None;
    let mut pos = 0;

    while let Some(rel_open) = text[pos..].find("```") {
        let open = pos + rel_open;
        let fence_len = bytes[open..].iter().take_while(|&&b| b == b'`').count();

        // Skip whatever follows the backticks on the opening line (info string / path).
        let info_start = open + fence_len;
        let line_end = bytes[info_start..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |offset| info_start + offset);

        let closing = format!("\n{}", "`".repeat(fence_len));
        let Some(rel_close) = text[line_end..].find(closing.as_str()) else {
            break;
        };

        // Body spans from just after the opening line's newline through the newline that
        // precedes the closing fence.
        last = Some(&text[line_end + 1..line_end + rel_close + 1]);
        pos = line_end + rel_close + closing.len();
    }

    last.unwrap_or(text).to_string()
}
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    #[test]
    fn test_extract_last_codeblock_without_complete_fence_returns_input() {
        assert_eq!(extract_last_codeblock("no fences here"), "no fences here");
        assert_eq!(
            extract_last_codeblock("```\nunterminated"),
            "```\nunterminated"
        );
    }

    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text).unwrap();
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }

    #[test]
    fn test_extract_editable_region_rejects_end_before_start() {
        let text = "\n<|editable_region_end|>\nfoo\n<|editable_region_start|>\n";
        assert!(TeacherPrompt::extract_editable_region(text).is_err());
    }

    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                title={Unique3D},
            }
            ```
            `````
            "#};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
                ## Collaborations
                Our mission is to create a 4D generative model.

                ## Citation

                If you found Unique3D helpful, please cite our report:
                ```bibtex
                @misc{wu2024unique3d,
                    title={Unique3D},
                }
                ```
                "#}
        );
    }

    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text).unwrap();
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }

    /// Covers the `NO_EDITS` extraction step that `TeacherPrompt::parse` relies on; `parse`
    /// itself needs a fully populated `Example` and is not exercised here.
    #[test]
    fn test_parse_no_edits_response() {
        let response = indoc::indoc! {"
            The code is already complete. There is no clear next edit to make.

            `````
            NO_EDITS
            `````
            "};
        let codeblock = extract_last_codeblock(response);
        assert_eq!(codeblock.trim(), TeacherPrompt::NO_EDITS);
    }

    #[test]
    fn test_format_edit_history_keeps_only_diff_lines() {
        let history = "// a comment\n--- a/src/main.rs\n+++ b/src/main.rs\n@@ -1,2 +1,2 @@\n fn main() {\n-    old();\n+    new();\ngarbage\n";
        assert_eq!(
            TeacherPrompt::format_edit_history(history),
            "--- a/src/main.rs\n+++ b/src/main.rs\n@@ -1,2 +1,2 @@\n fn main() {\n-    old();\n+    new();"
        );
    }

    #[test]
    fn test_format_edit_history_empty_and_truncated() {
        assert_eq!(
            TeacherPrompt::format_edit_history("just prose"),
            "(No edit history)"
        );

        let history = (0..200)
            .map(|i| format!("+line {i}"))
            .collect::<Vec<_>>()
            .join("\n");
        let formatted = TeacherPrompt::format_edit_history(&history);
        let lines: Vec<&str> = formatted.lines().collect();
        assert_eq!(lines.len(), TeacherPrompt::MAX_HISTORY_LINES);
        assert_eq!(lines.first(), Some(&"+line 72"));
        assert_eq!(lines.last(), Some(&"+line 199"));
    }
}