1use crate::{
2 FormatPromptArgs, PredictionProvider,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 progress::{ExampleProgress, Step},
6 retrieve_context::run_context_retrieval,
7};
8use anyhow::{Context as _, Result};
9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
10use gpui::{AppContext, AsyncApp};
11use language::{Buffer, OffsetRangeExt, Point};
12use similar::DiffableStr;
13use std::sync::Arc;
14use std::{fmt::Write as _, ops::Range};
15use zeta_prompt::format_zeta_prompt;
16
/// Runs context retrieval for `example`, then formats the prompt for `args.provider` and
/// stores it, together with the expected output, in `example.prompt`.
///
/// `Teacher`, `TeacherNonBatching` and `Zeta2` are handled; for each of them the editable and
/// context ranges around the cursor are computed from a freshly built buffer snapshot.
///
/// # Errors
///
/// Returns an error if context retrieval fails, if `example.prompt_inputs` is still unset
/// afterwards, or, for `Zeta2`, if the spec has no expected patch or the first expected patch
/// cannot be applied to the editable region.
///
/// # Panics
///
/// Panics when `args.provider` is any other provider.
pub async fn run_format_prompt(
    example: &mut Example,
    args: &FormatPromptArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    cx: AsyncApp,
) -> Result<()> {
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::FormatPrompt);

    let prompt_inputs = example
        .prompt_inputs
        .as_ref()
        .context("prompt_inputs must be set after context retrieval")?;

    // A language that fails to load is tolerated: `.ok()` turns the failure into `None` and
    // the snapshot is built without a language.
    let language = app_state
        .languages
        .load_language_for_file_path(&example.spec.cursor_path)
        .await
        .ok();
    // `build_snapshot` is called inside `cx.update`; what it returns is awaited below through
    // `background_spawn`.
    let snapshot_fut = cx.update(|cx| {
        Buffer::build_snapshot(
            prompt_inputs.content.as_str().into(),
            language,
            Some(app_state.languages.clone()),
            cx,
        )
    });
    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
    let snapshot = cx.background_spawn(snapshot_fut).await;

    match args.provider {
        PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
            step_progress.set_substatus("formatting teacher prompt");

            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
                cursor_point,
                &snapshot,
                edit_prediction::zeta2::max_editable_tokens(version),
                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
            );
            // Convert both ranges to offsets into the snapshot so they can be used to slice
            // the buffer content.
            let editable_range = editable_range.to_offset(&snapshot);
            let context_range = context_range.to_offset(&snapshot);

            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
            example.prompt = Some(ExamplePrompt {
                input: prompt,
                // Unlike the Zeta2 branch, a missing expected patch is not an error here: the
                // expected output falls back to the default value.
                expected_output: example
                    .spec
                    .expected_patches
                    .first()
                    .cloned()
                    .unwrap_or_default(),
                provider: args.provider,
            });
        }
        PredictionProvider::Zeta2(version) => {
            step_progress.set_substatus("formatting zeta2 prompt");

            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
                cursor_point,
                &snapshot,
                edit_prediction::zeta2::max_editable_tokens(version),
                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
            );
            let editable_range = editable_range.to_offset(&snapshot);
            let context_range = context_range.to_offset(&snapshot);

            // The prompt input only carries the context excerpt, so the cursor offset and the
            // editable range are re-expressed relative to the start of that excerpt.
            // NOTE(review): these subtractions underflow if the cursor or the editable range
            // starts before `context_start` — presumably the context range always contains
            // both; confirm.
            let context_start = context_range.start;
            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
            let editable_range_in_excerpt =
                (editable_range.start - context_start)..(editable_range.end - context_start);
            let input = zeta_prompt::ZetaPromptInput {
                cursor_path: example.spec.cursor_path.clone(),
                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
                editable_range_in_excerpt,
                cursor_offset_in_excerpt,
                events: prompt_inputs.edit_history.clone(),
                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
            };
            let prompt = format_zeta_prompt(&input, version);
            // The expected output is the editable region with the first expected patch applied.
            let expected_output = zeta2_output_for_patch(
                &input,
                &example
                    .spec
                    .expected_patches
                    .first()
                    .context("expected patches is empty")?
                    .clone(),
            )?;
            example.prompt = Some(ExamplePrompt {
                input: prompt,
                expected_output,
                provider: args.provider,
            });
        }
        _ => {
            // NOTE(review): this aborts with a panic although the function returns `Result`;
            // consider whether an error would serve callers better.
            panic!("Cannot format prompt for {:?}", args.provider);
        }
    };
    Ok(())
}
120
121pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
122 let mut old_editable_region =
123 input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
124
125 if !old_editable_region.ends_with_newline() {
126 old_editable_region.push('\n');
127 }
128
129 edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
130 format!(
131 "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
132 patch, old_editable_region
133 )
134 })
135}
136
/// Namespace type for building the teacher prompt (`format_prompt`) and for turning a model
/// response back into a unified diff (`parse`). It carries no data.
pub struct TeacherPrompt;
138
139impl TeacherPrompt {
    /// Prompt template embedded at compile time from `teacher.prompt.md`. `format_prompt`
    /// substitutes the `{{context}}`, `{{edit_history}}` and `{{cursor_excerpt}}` placeholders.
    const PROMPT: &str = include_str!("teacher.prompt.md");
    /// Marker written immediately before the editable region; includes its trailing newline.
    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
    /// Marker written immediately after the editable region; includes its leading newline.
    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
    /// Marker written at the cursor offset inside the editable region; it is removed again by
    /// `extract_editable_region`.
    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";

    /// Truncate edit history to this number of last lines
    const MAX_HISTORY_LINES: usize = 128;
147
148 pub fn format_prompt(
149 example: &Example,
150 editable_range: Range<usize>,
151 context_range: Range<usize>,
152 ) -> String {
153 let edit_history = Self::format_edit_history(&example.spec.edit_history);
154 let context = Self::format_context(example);
155 let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
156
157 let prompt = Self::PROMPT
158 .replace("{{context}}", &context)
159 .replace("{{edit_history}}", &edit_history)
160 .replace("{{cursor_excerpt}}", &cursor_excerpt);
161
162 prompt
163 }
164
165 pub fn parse(example: &Example, response: &str) -> Result<String> {
166 // Extract updated (new) editable region from the model response.
167 // The model may include editable region markers in its output, so we need to strip them.
168 let new_editable_region = extract_last_codeblock(response);
169 let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
170 let old_editable_region = Self::extract_editable_region(
171 &example
172 .prompt
173 .as_ref()
174 .context("example prompt missing")?
175 .input,
176 );
177 let prompt_inputs = example
178 .prompt_inputs
179 .as_ref()
180 .context("example is missing prompt inputs")?;
181
182 // Normalize leading newlines: if old starts with newline but new doesn't,
183 // prepend newline to new to preserve whitespace structure.
184 // This handles the case where the model drops the leading blank line.
185 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
186 new_editable_region.insert(0, '\n');
187 }
188
189 let (editable_region_offset, _) = prompt_inputs
190 .content
191 .match_indices(&old_editable_region)
192 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
193 .unwrap();
194 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
195 .matches('\n')
196 .count();
197
198 let diff = language::unified_diff_with_offsets(
199 &old_editable_region,
200 &new_editable_region,
201 editable_region_start_line as u32,
202 editable_region_start_line as u32,
203 );
204
205 let diff = indoc::formatdoc! {"
206 --- a/{path}
207 +++ b/{path}
208 {diff}",
209 path = example.spec.cursor_path.to_string_lossy(),
210 diff = diff,
211 };
212
213 Ok(diff)
214 }
215
216 fn format_edit_history(edit_history: &str) -> String {
217 // Strip comments ("garbage lines") from edit history
218 let lines = edit_history
219 .lines()
220 .filter(|&s| Self::is_udiff_content_line(s))
221 .collect::<Vec<_>>();
222
223 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
224 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
225 } else {
226 &lines
227 };
228
229 if history_lines.is_empty() {
230 return "(No edit history)".to_string();
231 }
232
233 history_lines.join("\n")
234 }
235
236 fn format_context(example: &Example) -> String {
237 let related_files = example
238 .prompt_inputs
239 .as_ref()
240 .and_then(|pi| pi.related_files.as_ref());
241
242 let Some(related_files) = related_files else {
243 return "(No context)".to_string();
244 };
245
246 if related_files.is_empty() {
247 return "(No context)".to_string();
248 }
249
250 let mut prompt = String::new();
251 for file in related_files {
252 let path_str = file.path.to_string_lossy();
253 writeln!(&mut prompt, "`````{path_str}").ok();
254 let mut prev_row = 0;
255 for excerpt in &file.excerpts {
256 if excerpt.row_range.start > prev_row {
257 prompt.push_str("…\n");
258 }
259 prompt.push_str(&excerpt.text);
260 prompt.push('\n');
261 prev_row = excerpt.row_range.end;
262 }
263 if prev_row < file.max_row {
264 prompt.push_str("…\n");
265 }
266 prompt.push_str("\n`````");
267 }
268
269 prompt
270 }
271
272 fn format_cursor_excerpt(
273 example: &Example,
274 editable_range: Range<usize>,
275 context_range: Range<usize>,
276 ) -> String {
277 let mut result = String::new();
278
279 let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
280
281 let path_str = example.spec.cursor_path.to_string_lossy();
282 result.push_str(&format!("`````{path_str}\n"));
283 result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
284 result.push_str(Self::EDITABLE_REGION_START);
285 result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
286 result.push_str(Self::USER_CURSOR_MARKER);
287 result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
288 result.push_str(Self::EDITABLE_REGION_END);
289 result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
290 result.push_str("\n`````");
291
292 result
293 }
294
295 fn extract_editable_region(text: &str) -> String {
296 let start = text
297 .find(Self::EDITABLE_REGION_START)
298 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
299 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
300
301 let region = &text[start..end];
302 let region = region.strip_suffix('\n').unwrap_or(region);
303
304 region.replace(Self::USER_CURSOR_MARKER, "")
305 }
306
307 fn is_udiff_content_line(s: &str) -> bool {
308 s.starts_with("-")
309 || s.starts_with("+")
310 || s.starts_with(" ")
311 || s.starts_with("---")
312 || s.starts_with("+++")
313 || s.starts_with("@@")
314 }
315}
316
/// Returns the contents of the last complete fenced code block in `text`.
///
/// An opening fence is any run of three or more backticks; the rest of that line is skipped.
/// The block is closed by the first following line that begins with the same number of
/// backticks, which lets a longer fence enclose shorter ones. The returned contents include
/// the newline that precedes the closing fence.
///
/// Scanning stops at the first fence that is never closed. If no complete block is found,
/// the whole of `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let mut latest_block: Option<&str> = None;
    let mut cursor = 0;

    while let Some(found) = text[cursor..].find("```") {
        let fence_start = cursor + found;
        let from_fence = &text[fence_start..];

        // Length of the whole backtick run, which may be longer than three.
        let fence_len = from_fence.bytes().take_while(|&byte| byte == b'`').count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // End of the opening line: index of its '\n', or the end of the text.
        let header_end = from_fence
            .find('\n')
            .map_or(text.len(), |newline| fence_start + newline);

        let Some(close_at) = text[header_end..].find(&closing_fence) else {
            break;
        };

        // Skip the newline ending the opening line; keep the one before the closing fence.
        let body_start = header_end + 1;
        let body_end = header_end + close_at + 1;
        latest_block = Some(&text[body_start..body_end]);
        cursor = header_end + close_at + closing_fence.len();
    }

    latest_block.unwrap_or(text).to_string()
}
348
#[cfg(test)]
mod tests {
    use super::*;

    // When a response holds several code blocks, the contents of the last one are returned,
    // even if its opening fence carries an info string.
    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
        "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    // Triple-backtick fences inside a five-backtick block do not terminate the block.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
        "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    // Backticks that appear inside the block's lines are kept verbatim.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
        "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    // Only the text between the markers is returned, without the trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
        "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // A complete ```bibtex block nested in a five-backtick block is part of the contents.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
            `````
        "#};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
            "#}
        );
    }

    // Without markers the whole text is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // The user cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let text = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
        "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }
}