1use crate::{
2 FormatPromptArgs, PredictionProvider,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 progress::{ExampleProgress, Step},
6 retrieve_context::run_context_retrieval,
7};
8use anyhow::{Context as _, Result};
9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
10use gpui::{AppContext, AsyncApp};
11use language::{Buffer, OffsetRangeExt, Point};
12use similar::DiffableStr;
13use std::sync::Arc;
14use std::{fmt::Write as _, ops::Range};
15use zeta_prompt::format_zeta_prompt;
16
17pub async fn run_format_prompt(
18 example: &mut Example,
19 args: &FormatPromptArgs,
20 app_state: Arc<EpAppState>,
21 example_progress: &ExampleProgress,
22 cx: AsyncApp,
23) -> Result<()> {
24 run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
25
26 let step_progress = example_progress.start(Step::FormatPrompt);
27
28 let prompt_inputs = example
29 .prompt_inputs
30 .as_ref()
31 .context("prompt_inputs must be set after context retrieval")?;
32
33 let language = app_state
34 .languages
35 .load_language_for_file_path(&example.spec.cursor_path)
36 .await
37 .ok();
38 let snapshot_fut = cx.update(|cx| {
39 Buffer::build_snapshot(
40 prompt_inputs.content.as_str().into(),
41 language,
42 Some(app_state.languages.clone()),
43 cx,
44 )
45 });
46 let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
47 let snapshot = cx.background_spawn(snapshot_fut).await;
48
49 match args.provider {
50 PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
51 step_progress.set_substatus("formatting teacher prompt");
52
53 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
54 cursor_point,
55 &snapshot,
56 edit_prediction::zeta2::max_editable_tokens(version),
57 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
58 );
59 let editable_range = editable_range.to_offset(&snapshot);
60 let context_range = context_range.to_offset(&snapshot);
61
62 let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
63 example.prompt = Some(ExamplePrompt {
64 input: prompt,
65 expected_output: example
66 .spec
67 .expected_patches
68 .first()
69 .cloned()
70 .unwrap_or_default(),
71 provider: args.provider,
72 });
73 }
74 PredictionProvider::Zeta2(version) => {
75 step_progress.set_substatus("formatting zeta2 prompt");
76
77 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
78 cursor_point,
79 &snapshot,
80 edit_prediction::zeta2::max_editable_tokens(version),
81 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
82 );
83 let editable_range = editable_range.to_offset(&snapshot);
84 let context_range = context_range.to_offset(&snapshot);
85
86 let context_start = context_range.start;
87 let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
88 let editable_range_in_excerpt =
89 (editable_range.start - context_start)..(editable_range.end - context_start);
90 let input = zeta_prompt::ZetaPromptInput {
91 cursor_path: example.spec.cursor_path.clone(),
92 cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
93 editable_range_in_excerpt,
94 cursor_offset_in_excerpt,
95 events: prompt_inputs.edit_history.clone(),
96 related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
97 };
98 let prompt = format_zeta_prompt(&input, version);
99 let expected_output = zeta2_output_for_patch(
100 &input,
101 &example
102 .spec
103 .expected_patches
104 .first()
105 .context("expected patches is empty")?
106 .clone(),
107 )?;
108 example.prompt = Some(ExamplePrompt {
109 input: prompt,
110 expected_output,
111 provider: args.provider,
112 });
113 }
114 _ => {
115 panic!("Cannot format prompt for {:?}", args.provider);
116 }
117 };
118 Ok(())
119}
120
121pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
122 let mut old_editable_region =
123 input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
124
125 if !old_editable_region.ends_with_newline() {
126 old_editable_region.push('\n');
127 }
128
129 edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
130 format!(
131 "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
132 patch, old_editable_region
133 )
134 })
135}
136
137pub struct TeacherPrompt;
138
139impl TeacherPrompt {
140 const PROMPT: &str = include_str!("teacher.prompt.md");
141 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
142 pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
143 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
144
145 /// Truncate edit history to this number of last lines
146 const MAX_HISTORY_LINES: usize = 128;
147
148 pub fn format_prompt(
149 example: &Example,
150 editable_range: Range<usize>,
151 context_range: Range<usize>,
152 ) -> String {
153 let edit_history = Self::format_edit_history(&example.spec.edit_history);
154 let context = Self::format_context(example);
155 let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
156
157 let prompt = Self::PROMPT
158 .replace("{{context}}", &context)
159 .replace("{{edit_history}}", &edit_history)
160 .replace("{{cursor_excerpt}}", &cursor_excerpt);
161
162 prompt
163 }
164
165 pub fn parse(example: &Example, response: &str) -> Result<String> {
166 // Extract updated (new) editable region from the model response.
167 // The model may include editable region markers in its output, so we need to strip them.
168 let new_editable_region = extract_last_codeblock(response);
169 let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
170 let old_editable_region = Self::extract_editable_region(
171 &example
172 .prompt
173 .as_ref()
174 .context("example prompt missing")?
175 .input,
176 );
177 let prompt_inputs = example
178 .prompt_inputs
179 .as_ref()
180 .context("example is missing prompt inputs")?;
181
182 // Normalize leading newlines: if old starts with newline but new doesn't,
183 // prepend newline to new to preserve whitespace structure.
184 // This handles the case where the model drops the leading blank line.
185 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
186 new_editable_region.insert(0, '\n');
187 }
188
189 let (editable_region_offset, _) = prompt_inputs
190 .content
191 .match_indices(&old_editable_region)
192 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
193 .unwrap();
194 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
195 .matches('\n')
196 .count();
197
198 let diff = language::unified_diff_with_offsets(
199 &old_editable_region,
200 &new_editable_region,
201 editable_region_start_line as u32,
202 editable_region_start_line as u32,
203 );
204
205 let diff = indoc::formatdoc! {"
206 --- a/{path}
207 +++ b/{path}
208 {diff}",
209 path = example.spec.cursor_path.to_string_lossy(),
210 diff = diff,
211 };
212
213 Ok(diff)
214 }
215
216 fn format_edit_history(edit_history: &str) -> String {
217 // Strip comments ("garbage lines") from edit history
218 let lines = edit_history
219 .lines()
220 .filter(|&s| Self::is_udiff_content_line(s))
221 .collect::<Vec<_>>();
222
223 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
224 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
225 } else {
226 &lines
227 };
228
229 if history_lines.is_empty() {
230 return "(No edit history)".to_string();
231 }
232
233 history_lines.join("\n")
234 }
235
236 fn format_context(example: &Example) -> String {
237 let related_files = example
238 .prompt_inputs
239 .as_ref()
240 .and_then(|pi| pi.related_files.as_ref());
241
242 let Some(related_files) = related_files else {
243 return "(No context)".to_string();
244 };
245
246 if related_files.is_empty() {
247 return "(No context)".to_string();
248 }
249
250 let mut prompt = String::new();
251 for file in related_files {
252 let path_str = file.path.to_string_lossy();
253 writeln!(&mut prompt, "`````{path_str}").ok();
254
255 let mut prev_row = 0;
256 for excerpt in &file.excerpts {
257 if excerpt.row_range.start > prev_row {
258 prompt.push_str("…\n");
259 }
260 prompt.push_str(&excerpt.text);
261 prompt.push('\n');
262 prev_row = excerpt.row_range.end;
263 }
264 if prev_row < file.max_row {
265 prompt.push_str("…\n");
266 }
267 prompt.push_str("\n`````\n");
268 }
269
270 prompt
271 }
272
273 fn format_cursor_excerpt(
274 example: &Example,
275 editable_range: Range<usize>,
276 context_range: Range<usize>,
277 ) -> String {
278 let mut result = String::new();
279
280 let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
281
282 let path_str = example.spec.cursor_path.to_string_lossy();
283 result.push_str(&format!("`````{path_str}\n"));
284 result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
285 result.push_str(Self::EDITABLE_REGION_START);
286 result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
287 result.push_str(Self::USER_CURSOR_MARKER);
288 result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
289 result.push_str(Self::EDITABLE_REGION_END);
290 result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
291 result.push_str("\n`````");
292
293 result
294 }
295
296 fn extract_editable_region(text: &str) -> String {
297 let start = text
298 .rfind(Self::EDITABLE_REGION_START)
299 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
300 let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
301
302 let region = &text[start..end];
303 let region = region.strip_suffix('\n').unwrap_or(region);
304
305 region.replace(Self::USER_CURSOR_MARKER, "")
306 }
307
308 fn is_udiff_content_line(s: &str) -> bool {
309 s.starts_with("-")
310 || s.starts_with("+")
311 || s.starts_with(" ")
312 || s.starts_with("---")
313 || s.starts_with("+++")
314 || s.starts_with("@@")
315 }
316}
317
/// Returns the contents of the last complete fenced code block in `text`.
///
/// A fence is a run of three or more backticks, and the rest of its line (a language tag, a
/// path, …) is ignored. The block ends at the first later occurrence of a newline followed by
/// the same number of backticks. This lets outer fences be made longer than any fences nested
/// inside them. The returned text starts after the opening fence's line and keeps the newline
/// that precedes the closing fence. If no complete block exists, `text` is returned unchanged.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block: Option<&str> = None;
    let mut cursor = 0;

    while let Some(relative) = text[cursor..].find("```") {
        let fence_start = cursor + relative;
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip the remainder of the opening line (its info string); stop at its newline, if any.
        let info_start = fence_start + fence_len;
        let line_end = bytes[info_start..]
            .iter()
            .position(|&byte| byte == b'\n')
            .map_or(bytes.len(), |offset| info_start + offset);

        let Some(close_offset) = text[line_end..].find(&closing_fence) else {
            break;
        };
        last_block = Some(&text[line_end + 1..line_end + close_offset + 1]);
        cursor = line_end + close_offset + closing_fence.len();
    }

    last_block.unwrap_or(text).to_string()
}
349
// Unit tests for fenced-code-block extraction (`extract_last_codeblock`) and editable-region
// extraction (`TeacherPrompt::extract_editable_region`).
#[cfg(test)]
mod tests {
    use super::*;

    // With several fenced blocks, only the last complete one is returned; the info string on
    // its opening fence line (a path annotation here) is not part of the content.
    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    // Shorter backtick runs inside a five-backtick block, even at line start, do not close it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    // Single backticks and mid-line triple backticks inside the block are kept verbatim.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    // Only the text between the region markers is returned, with one trailing newline trimmed.
    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }

    // A realistic model response: a five-backtick block wrapping a nested ```bibtex fence.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
            `````
            "#};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
                ## Collaborations
                Our mission is to create a 4D generative model.

                ## Citation

                If you found Unique3D helpful, please cite our report:
                ```bibtex
                @misc{wu2024unique3d,
                title={Unique3D},
                }
                ```
                "#}
        );
    }

    // Without any markers, the whole text is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }

    // The `<|user_cursor|>` marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let text = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }
}