1use crate::{
2 FormatPromptArgs, PredictionProvider,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 progress::{Progress, Step},
6 retrieve_context::run_context_retrieval,
7};
8use anyhow::{Context as _, Result};
9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
10use gpui::{AppContext, AsyncApp};
11use language::{Buffer, OffsetRangeExt, Point};
12use similar::DiffableStr;
13use std::sync::Arc;
14use std::{fmt::Write as _, ops::Range};
15use zeta_prompt::format_zeta_prompt;
16
17pub async fn run_format_prompt(
18 example: &mut Example,
19 args: &FormatPromptArgs,
20 app_state: Arc<EpAppState>,
21 cx: AsyncApp,
22) -> Result<()> {
23 run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
24
25 let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
26
27 let prompt_inputs = example
28 .prompt_inputs
29 .as_ref()
30 .context("prompt_inputs must be set after context retrieval")?;
31
32 let language = app_state
33 .languages
34 .load_language_for_file_path(&example.spec.cursor_path)
35 .await
36 .ok();
37 let snapshot_fut = cx.update(|cx| {
38 Buffer::build_snapshot(
39 prompt_inputs.content.as_str().into(),
40 language,
41 Some(app_state.languages.clone()),
42 cx,
43 )
44 });
45 let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
46 let snapshot = cx.background_spawn(snapshot_fut).await;
47
48 match args.provider {
49 PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
50 step_progress.set_substatus("formatting teacher prompt");
51
52 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
53 cursor_point,
54 &snapshot,
55 edit_prediction::zeta2::max_editable_tokens(version),
56 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
57 );
58 let editable_range = editable_range.to_offset(&snapshot);
59 let context_range = context_range.to_offset(&snapshot);
60
61 let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
62 example.prompt = Some(ExamplePrompt {
63 input: prompt,
64 expected_output: example
65 .spec
66 .expected_patches
67 .first()
68 .cloned()
69 .unwrap_or_default(),
70 provider: args.provider,
71 });
72 }
73 PredictionProvider::Zeta2(version) => {
74 step_progress.set_substatus("formatting zeta2 prompt");
75
76 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
77 cursor_point,
78 &snapshot,
79 edit_prediction::zeta2::max_editable_tokens(version),
80 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
81 );
82 let editable_range = editable_range.to_offset(&snapshot);
83 let context_range = context_range.to_offset(&snapshot);
84
85 let context_start = context_range.start;
86 let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
87 let editable_range_in_excerpt =
88 (editable_range.start - context_start)..(editable_range.end - context_start);
89 let input = zeta_prompt::ZetaPromptInput {
90 cursor_path: example.spec.cursor_path.clone(),
91 cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
92 editable_range_in_excerpt,
93 cursor_offset_in_excerpt,
94 events: prompt_inputs.edit_history.clone(),
95 related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
96 };
97 let prompt = format_zeta_prompt(&input, version);
98 let expected_output = zeta2_output_for_patch(
99 &input,
100 &example
101 .spec
102 .expected_patches
103 .first()
104 .context("expected patches is empty")?
105 .clone(),
106 )?;
107 example.prompt = Some(ExamplePrompt {
108 input: prompt,
109 expected_output,
110 provider: args.provider,
111 });
112 }
113 _ => {
114 panic!("Cannot format prompt for {:?}", args.provider);
115 }
116 };
117 Ok(())
118}
119
120pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
121 let mut old_editable_region =
122 input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
123
124 if !old_editable_region.ends_with_newline() {
125 old_editable_region.push('\n');
126 }
127
128 edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
129 format!(
130 "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
131 patch, old_editable_region
132 )
133 })
134}
135
136pub struct TeacherPrompt;
137
138impl TeacherPrompt {
139 const PROMPT: &str = include_str!("teacher.prompt.md");
140 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
141 pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
142 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
143
144 /// Truncate edit history to this number of last lines
145 const MAX_HISTORY_LINES: usize = 128;
146
147 pub fn format_prompt(
148 example: &Example,
149 editable_range: Range<usize>,
150 context_range: Range<usize>,
151 ) -> String {
152 let edit_history = Self::format_edit_history(&example.spec.edit_history);
153 let context = Self::format_context(example);
154 let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
155
156 let prompt = Self::PROMPT
157 .replace("{{context}}", &context)
158 .replace("{{edit_history}}", &edit_history)
159 .replace("{{cursor_excerpt}}", &cursor_excerpt);
160
161 prompt
162 }
163
164 pub fn parse(example: &Example, response: &str) -> Result<String> {
165 // Extract updated (new) editable region from the model response.
166 // The model may include editable region markers in its output, so we need to strip them.
167 let new_editable_region = extract_last_codeblock(response);
168 let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
169 let old_editable_region = Self::extract_editable_region(
170 &example
171 .prompt
172 .as_ref()
173 .context("example prompt missing")?
174 .input,
175 );
176 let prompt_inputs = example
177 .prompt_inputs
178 .as_ref()
179 .context("example is missing prompt inputs")?;
180
181 // Normalize leading newlines: if old starts with newline but new doesn't,
182 // prepend newline to new to preserve whitespace structure.
183 // This handles the case where the model drops the leading blank line.
184 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
185 new_editable_region.insert(0, '\n');
186 }
187
188 let (editable_region_offset, _) = prompt_inputs
189 .content
190 .match_indices(&old_editable_region)
191 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
192 .unwrap();
193 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
194 .matches('\n')
195 .count();
196
197 let diff = language::unified_diff_with_offsets(
198 &old_editable_region,
199 &new_editable_region,
200 editable_region_start_line as u32,
201 editable_region_start_line as u32,
202 );
203
204 let diff = indoc::formatdoc! {"
205 --- a/{path}
206 +++ b/{path}
207 {diff}",
208 path = example.spec.cursor_path.to_string_lossy(),
209 diff = diff,
210 };
211
212 Ok(diff)
213 }
214
215 fn format_edit_history(edit_history: &str) -> String {
216 // Strip comments ("garbage lines") from edit history
217 let lines = edit_history
218 .lines()
219 .filter(|&s| Self::is_udiff_content_line(s))
220 .collect::<Vec<_>>();
221
222 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
223 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
224 } else {
225 &lines
226 };
227
228 if history_lines.is_empty() {
229 return "(No edit history)".to_string();
230 }
231
232 history_lines.join("\n")
233 }
234
235 fn format_context(example: &Example) -> String {
236 let related_files = example
237 .prompt_inputs
238 .as_ref()
239 .and_then(|pi| pi.related_files.as_ref());
240
241 let Some(related_files) = related_files else {
242 return "(No context)".to_string();
243 };
244
245 if related_files.is_empty() {
246 return "(No context)".to_string();
247 }
248
249 let mut prompt = String::new();
250 for file in related_files {
251 let path_str = file.path.to_string_lossy();
252 writeln!(&mut prompt, "`````{path_str}").ok();
253 let mut prev_row = 0;
254 for excerpt in &file.excerpts {
255 if excerpt.row_range.start > prev_row {
256 prompt.push_str("…\n");
257 }
258 prompt.push_str(&excerpt.text);
259 prompt.push('\n');
260 prev_row = excerpt.row_range.end;
261 }
262 if prev_row < file.max_row {
263 prompt.push_str("…\n");
264 }
265 prompt.push_str("\n`````");
266 }
267
268 prompt
269 }
270
271 fn format_cursor_excerpt(
272 example: &Example,
273 editable_range: Range<usize>,
274 context_range: Range<usize>,
275 ) -> String {
276 let mut result = String::new();
277
278 let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
279
280 let path_str = example.spec.cursor_path.to_string_lossy();
281 result.push_str(&format!("`````{path_str}\n"));
282 result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
283 result.push_str(Self::EDITABLE_REGION_START);
284 result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
285 result.push_str(Self::USER_CURSOR_MARKER);
286 result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
287 result.push_str(Self::EDITABLE_REGION_END);
288 result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
289 result.push_str("\n`````");
290
291 result
292 }
293
294 fn extract_editable_region(text: &str) -> String {
295 let start = text
296 .find(Self::EDITABLE_REGION_START)
297 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
298 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
299
300 let region = &text[start..end];
301 let region = region.strip_suffix('\n').unwrap_or(region);
302
303 region.replace(Self::USER_CURSOR_MARKER, "")
304 }
305
306 fn is_udiff_content_line(s: &str) -> bool {
307 s.starts_with("-")
308 || s.starts_with("+")
309 || s.starts_with(" ")
310 || s.starts_with("---")
311 || s.starts_with("+++")
312 || s.starts_with("@@")
313 }
314}
315
/// Returns the body of the last closed fenced code block in `text`.
///
/// A fence is a run of three or more backticks; anything after it on the same line is
/// skipped. The block is closed by a newline followed by the same number of backticks, so a
/// longer outer fence may contain shorter fences. The returned body includes the newline that
/// precedes the closing fence. When `text` has no closed block, it is returned unchanged.
fn extract_last_codeblock(text: &str) -> String {
    let mut latest_body: Option<&str> = None;
    let mut cursor = 0;

    while let Some(relative_start) = text[cursor..].find("```") {
        let fence_start = cursor + relative_start;

        // Measure the whole opening fence, which may be longer than three backticks.
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip whatever follows the fence on its line.
        let after_fence = fence_start + fence_len;
        let header_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |offset| after_fence + offset);

        let Some(body_len) = text[header_end..].find(&closing_fence) else {
            break;
        };

        // The body starts after the header's newline and keeps its own final newline.
        latest_body = Some(&text[header_end + 1..header_end + body_len + 1]);
        cursor = header_end + body_len + closing_fence.len();
    }

    latest_body.unwrap_or(text).to_string()
}
347
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// With several code blocks in the response, only the final one is returned, and the
    /// info string after its opening fence is dropped.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
        "};

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Shorter fences inside a five-backtick block do not terminate it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
        "};
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Inline backticks, single or tripled, are kept as part of the block body.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
        "};
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Only the text between the markers is returned, without its trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let prompt = indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
        "};

        let region = TeacherPrompt::extract_editable_region(prompt);

        assert_eq!(region, "one\ntwo three");
    }

    /// A nested three-backtick block at the end of a five-backtick block stays in the body.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let expected = indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
        "#};
        // The response is the expected body wrapped in a five-backtick fence, after a
        // sentence of model commentary.
        let response = format!(
            "Looking at the edit history, I can see that a Citation section was just added.\n\n`````\n{expected}`````\n"
        );

        assert_eq!(extract_last_codeblock(&response), expected);
    }

    /// Text without markers is returned as-is.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let unmarked = "one\ntwo three";

        let region = TeacherPrompt::extract_editable_region(unmarked);

        assert_eq!(region, unmarked);
    }

    /// The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let prompt = indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
        "};

        let region = TeacherPrompt::extract_editable_region(prompt);

        assert_eq!(region, "one\ntwo three");
    }
}