1use crate::{
2 FormatPromptArgs, PredictionProvider,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 progress::{ExampleProgress, Step},
6 retrieve_context::run_context_retrieval,
7};
8use anyhow::{Context as _, Result};
9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
10use gpui::{AppContext, AsyncApp};
11use language::{Buffer, OffsetRangeExt, Point};
12use similar::DiffableStr;
13use std::sync::Arc;
14use std::{fmt::Write as _, ops::Range};
15use zeta_prompt::ZetaVersion;
16use zeta_prompt::format_zeta_prompt;
17
18pub async fn run_format_prompt(
19 example: &mut Example,
20 args: &FormatPromptArgs,
21 app_state: Arc<EpAppState>,
22 example_progress: &ExampleProgress,
23 cx: AsyncApp,
24) -> Result<()> {
25 run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
26
27 let step_progress = example_progress.start(Step::FormatPrompt);
28
29 let prompt_inputs = example
30 .prompt_inputs
31 .as_ref()
32 .context("prompt_inputs must be set after context retrieval")?;
33
34 let language = app_state
35 .languages
36 .load_language_for_file_path(&example.spec.cursor_path)
37 .await
38 .ok();
39 let snapshot_fut = cx.update(|cx| {
40 Buffer::build_snapshot(
41 prompt_inputs.content.as_str().into(),
42 language,
43 Some(app_state.languages.clone()),
44 None,
45 cx,
46 )
47 });
48 let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
49 let snapshot = cx.background_spawn(snapshot_fut).await;
50
51 match args.provider {
52 PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
53 step_progress.set_substatus("formatting teacher prompt");
54
55 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
56 cursor_point,
57 &snapshot,
58 edit_prediction::zeta2::max_editable_tokens(version),
59 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
60 );
61 let editable_range = editable_range.to_offset(&snapshot);
62 let context_range = context_range.to_offset(&snapshot);
63
64 let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
65 example.prompt = Some(ExamplePrompt {
66 input: prompt,
67 expected_output: example
68 .spec
69 .expected_patches
70 .first()
71 .cloned()
72 .unwrap_or_default(),
73 provider: args.provider,
74 });
75 }
76 PredictionProvider::Zeta2(version) => {
77 step_progress.set_substatus("formatting zeta2 prompt");
78
79 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
80 cursor_point,
81 &snapshot,
82 edit_prediction::zeta2::max_editable_tokens(version),
83 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
84 );
85 let editable_range = editable_range.to_offset(&snapshot);
86 let context_range = context_range.to_offset(&snapshot);
87
88 let context_start = context_range.start;
89 let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
90 let editable_range_in_excerpt =
91 (editable_range.start - context_start)..(editable_range.end - context_start);
92 let input = zeta_prompt::ZetaPromptInput {
93 cursor_path: example.spec.cursor_path.clone(),
94 cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
95 editable_range_in_excerpt,
96 cursor_offset_in_excerpt,
97 events: prompt_inputs.edit_history.clone(),
98 related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
99 };
100 let prompt = format_zeta_prompt(&input, version);
101 let expected_output = zeta2_output_for_patch(
102 &input,
103 &example
104 .spec
105 .expected_patches
106 .first()
107 .context("expected patches is empty")?
108 .clone(),
109 version,
110 )?;
111 example.prompt = Some(ExamplePrompt {
112 input: prompt,
113 expected_output,
114 provider: args.provider,
115 });
116 }
117 _ => {
118 panic!("Cannot format prompt for {:?}", args.provider);
119 }
120 };
121 Ok(())
122}
123
124pub fn zeta2_output_for_patch(
125 input: &zeta_prompt::ZetaPromptInput,
126 patch: &str,
127 version: ZetaVersion,
128) -> Result<String> {
129 let mut old_editable_region =
130 input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
131
132 if !old_editable_region.ends_with_newline() {
133 old_editable_region.push('\n');
134 }
135
136 let mut result = edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region)
137 .with_context(|| {
138 format!(
139 "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
140 patch, old_editable_region
141 )
142 })?;
143
144 if version == ZetaVersion::V0120GitMergeMarkers {
145 if !result.ends_with('\n') {
146 result.push('\n');
147 }
148 result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
149 }
150
151 Ok(result)
152}
153
154pub struct TeacherPrompt;
155
156impl TeacherPrompt {
157 const PROMPT: &str = include_str!("teacher.prompt.md");
158 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
159 pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
160 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
161
162 /// Truncate edit history to this number of last lines
163 const MAX_HISTORY_LINES: usize = 128;
164
165 pub fn format_prompt(
166 example: &Example,
167 editable_range: Range<usize>,
168 context_range: Range<usize>,
169 ) -> String {
170 let edit_history = Self::format_edit_history(&example.spec.edit_history);
171 let context = Self::format_context(example);
172 let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
173
174 let prompt = Self::PROMPT
175 .replace("{{context}}", &context)
176 .replace("{{edit_history}}", &edit_history)
177 .replace("{{cursor_excerpt}}", &cursor_excerpt);
178
179 prompt
180 }
181
182 pub fn parse(example: &Example, response: &str) -> Result<String> {
183 // Extract updated (new) editable region from the model response.
184 // The model may include editable region markers in its output, so we need to strip them.
185 let new_editable_region = extract_last_codeblock(response);
186 let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
187 let old_editable_region = Self::extract_editable_region(
188 &example
189 .prompt
190 .as_ref()
191 .context("example prompt missing")?
192 .input,
193 );
194 let prompt_inputs = example
195 .prompt_inputs
196 .as_ref()
197 .context("example is missing prompt inputs")?;
198
199 // Normalize leading newlines: if old starts with newline but new doesn't,
200 // prepend newline to new to preserve whitespace structure.
201 // This handles the case where the model drops the leading blank line.
202 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
203 new_editable_region.insert(0, '\n');
204 }
205
206 let (editable_region_offset, _) = prompt_inputs
207 .content
208 .match_indices(&old_editable_region)
209 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
210 .unwrap();
211 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
212 .matches('\n')
213 .count();
214
215 let diff = language::unified_diff_with_offsets(
216 &old_editable_region,
217 &new_editable_region,
218 editable_region_start_line as u32,
219 editable_region_start_line as u32,
220 );
221
222 let diff = indoc::formatdoc! {"
223 --- a/{path}
224 +++ b/{path}
225 {diff}",
226 path = example.spec.cursor_path.to_string_lossy(),
227 diff = diff,
228 };
229
230 Ok(diff)
231 }
232
233 fn format_edit_history(edit_history: &str) -> String {
234 // Strip comments ("garbage lines") from edit history
235 let lines = edit_history
236 .lines()
237 .filter(|&s| Self::is_udiff_content_line(s))
238 .collect::<Vec<_>>();
239
240 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
241 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
242 } else {
243 &lines
244 };
245
246 if history_lines.is_empty() {
247 return "(No edit history)".to_string();
248 }
249
250 history_lines.join("\n")
251 }
252
253 fn format_context(example: &Example) -> String {
254 let related_files = example
255 .prompt_inputs
256 .as_ref()
257 .and_then(|pi| pi.related_files.as_ref());
258
259 let Some(related_files) = related_files else {
260 return "(No context)".to_string();
261 };
262
263 if related_files.is_empty() {
264 return "(No context)".to_string();
265 }
266
267 let mut prompt = String::new();
268 for file in related_files {
269 let path_str = file.path.to_string_lossy();
270 writeln!(&mut prompt, "`````{path_str}").ok();
271
272 let mut prev_row = 0;
273 for excerpt in &file.excerpts {
274 if excerpt.row_range.start > prev_row {
275 prompt.push_str("…\n");
276 }
277 prompt.push_str(&excerpt.text);
278 prompt.push('\n');
279 prev_row = excerpt.row_range.end;
280 }
281 if prev_row < file.max_row {
282 prompt.push_str("…\n");
283 }
284 prompt.push_str("\n`````\n");
285 }
286
287 prompt
288 }
289
290 fn format_cursor_excerpt(
291 example: &Example,
292 editable_range: Range<usize>,
293 context_range: Range<usize>,
294 ) -> String {
295 let mut result = String::new();
296
297 let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
298
299 let path_str = example.spec.cursor_path.to_string_lossy();
300 result.push_str(&format!("`````{path_str}\n"));
301 result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
302 result.push_str(Self::EDITABLE_REGION_START);
303 result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
304 result.push_str(Self::USER_CURSOR_MARKER);
305 result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
306 result.push_str(Self::EDITABLE_REGION_END);
307 result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
308 result.push_str("\n`````");
309
310 result
311 }
312
313 fn extract_editable_region(text: &str) -> String {
314 let start = text
315 .rfind(Self::EDITABLE_REGION_START)
316 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
317 let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
318
319 let region = &text[start..end];
320 let region = region.strip_suffix('\n').unwrap_or(region);
321
322 region.replace(Self::USER_CURSOR_MARKER, "")
323 }
324
325 fn is_udiff_content_line(s: &str) -> bool {
326 s.starts_with("-")
327 || s.starts_with("+")
328 || s.starts_with(" ")
329 || s.starts_with("---")
330 || s.starts_with("+++")
331 || s.starts_with("@@")
332 }
333}
334
/// Returns the contents of the last complete fenced code block in `text`.
///
/// A fence is a run of three or more backticks; anything after it on the same line is
/// skipped. The block ends at the first following line that starts with at least as many
/// backticks as the opening fence. The returned contents include the newline that precedes
/// the closing fence. Scanning stops at the first fence that is never closed. When no
/// complete block is found, the whole of `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    while let Some(relative) = text[cursor..].find("```") {
        let fence_start = cursor + relative;

        // The opening fence may be longer than three backticks; the closing fence must be
        // a newline followed by the same number of backticks.
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip the rest of the opening line; the body starts after its newline.
        let after_fence = fence_start + fence_len;
        let line_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |offset| after_fence + offset);

        let Some(close_offset) = text[line_end..].find(&closing_fence) else {
            break;
        };

        latest = Some(&text[line_end + 1..line_end + close_offset + 1]);
        cursor = line_end + close_offset + closing_fence.len();
    }

    latest.map_or_else(|| text.to_string(), str::to_string)
}
366
// Unit tests for the string helpers: `extract_last_codeblock` and
// `TeacherPrompt::extract_editable_region`.
#[cfg(test)]
mod tests {
    use super::*;

    // With several fenced blocks, only the contents of the last one are returned; the
    // info string after the opening fence is not part of the contents.
    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    // Three-backtick fences appearing mid-line inside a five-backtick block do not end it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    // Single backticks and mid-line triple backticks are kept as block contents.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    // Text outside the markers is dropped, as is the newline right before the end marker
    // and one further trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }

    // A complete three-backtick block nested in a five-backtick block is returned as part
    // of the outer block's contents.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
            `````
            "#};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
                ## Collaborations
                Our mission is to create a 4D generative model.

                ## Citation

                If you found Unique3D helpful, please cite our report:
                ```bibtex
                @misc{wu2024unique3d,
                title={Unique3D},
                }
                ```
                "#}
        );
    }

    // Without any markers the whole text is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }

    // The user-cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let text = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
                one
                two three"}
        );
    }
}