1use crate::{
2 FormatPromptArgs, PredictionProvider,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 progress::{Progress, Step},
6 retrieve_context::run_context_retrieval,
7};
8use anyhow::{Context as _, Result};
9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
10use gpui::{AppContext, AsyncApp};
11use language::{Buffer, OffsetRangeExt, Point};
12use similar::DiffableStr;
13use std::sync::Arc;
14use std::{fmt::Write as _, ops::Range};
15use zeta_prompt::format_zeta_prompt;
16
17pub async fn run_format_prompt(
18 example: &mut Example,
19 args: &FormatPromptArgs,
20 app_state: Arc<EpAppState>,
21 cx: AsyncApp,
22) -> Result<()> {
23 run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
24
25 let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
26
27 let prompt_inputs = example
28 .prompt_inputs
29 .as_ref()
30 .context("prompt_inputs must be set after context retrieval")?;
31
32 let language = app_state
33 .languages
34 .load_language_for_file_path(&example.spec.cursor_path)
35 .await
36 .ok();
37 let snapshot_fut = cx.update(|cx| {
38 Buffer::build_snapshot(
39 prompt_inputs.content.as_str().into(),
40 language,
41 Some(app_state.languages.clone()),
42 cx,
43 )
44 });
45 let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
46 let snapshot = cx.background_spawn(snapshot_fut).await;
47
48 let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
49 cursor_point,
50 &snapshot,
51 edit_prediction::zeta2::MAX_EDITABLE_TOKENS,
52 edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
53 );
54 let editable_range = editable_range.to_offset(&snapshot);
55 let context_range = context_range.to_offset(&snapshot);
56
57 match args.provider {
58 PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
59 step_progress.set_substatus("formatting teacher prompt");
60 let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
61 example.prompt = Some(ExamplePrompt {
62 input: prompt,
63 expected_output: example
64 .spec
65 .expected_patches
66 .first()
67 .cloned()
68 .unwrap_or_default(),
69 provider: args.provider,
70 });
71 }
72 PredictionProvider::Zeta2(version) => {
73 step_progress.set_substatus("formatting zeta2 prompt");
74
75 let context_start = context_range.start;
76 let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
77 let editable_range_in_excerpt =
78 (editable_range.start - context_start)..(editable_range.end - context_start);
79 let input = zeta_prompt::ZetaPromptInput {
80 cursor_path: example.spec.cursor_path.clone(),
81 cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
82 editable_range_in_excerpt,
83 cursor_offset_in_excerpt,
84 events: prompt_inputs.edit_history.clone(),
85 related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
86 };
87 let prompt = format_zeta_prompt(&input, version);
88 let expected_output = zeta2_output_for_patch(
89 &input,
90 &example
91 .spec
92 .expected_patches
93 .first()
94 .context("expected patches is empty")?
95 .clone(),
96 )?;
97 example.prompt = Some(ExamplePrompt {
98 input: prompt,
99 expected_output,
100 provider: args.provider,
101 });
102 }
103 _ => {
104 panic!("Cannot format prompt for {:?}", args.provider);
105 }
106 };
107 Ok(())
108}
109
110pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
111 let mut old_editable_region =
112 input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
113
114 if !old_editable_region.ends_with_newline() {
115 old_editable_region.push('\n');
116 }
117
118 edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
119 format!(
120 "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
121 patch, old_editable_region
122 )
123 })
124}
125
126pub struct TeacherPrompt;
127
128impl TeacherPrompt {
129 const PROMPT: &str = include_str!("teacher.prompt.md");
130 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
131 pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
132 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
133
134 /// Truncate edit history to this number of last lines
135 const MAX_HISTORY_LINES: usize = 128;
136
137 pub fn format_prompt(
138 example: &Example,
139 editable_range: Range<usize>,
140 context_range: Range<usize>,
141 ) -> String {
142 let edit_history = Self::format_edit_history(&example.spec.edit_history);
143 let context = Self::format_context(example);
144 let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
145
146 let prompt = Self::PROMPT
147 .replace("{{context}}", &context)
148 .replace("{{edit_history}}", &edit_history)
149 .replace("{{cursor_excerpt}}", &cursor_excerpt);
150
151 prompt
152 }
153
154 pub fn parse(example: &Example, response: &str) -> Result<String> {
155 // Extract updated (new) editable region from the model response.
156 // The model may include editable region markers in its output, so we need to strip them.
157 let new_editable_region = extract_last_codeblock(response);
158 let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
159 let old_editable_region = Self::extract_editable_region(
160 &example
161 .prompt
162 .as_ref()
163 .context("example prompt missing")?
164 .input,
165 );
166 let prompt_inputs = example
167 .prompt_inputs
168 .as_ref()
169 .context("example is missing prompt inputs")?;
170
171 // Normalize leading newlines: if old starts with newline but new doesn't,
172 // prepend newline to new to preserve whitespace structure.
173 // This handles the case where the model drops the leading blank line.
174 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
175 new_editable_region.insert(0, '\n');
176 }
177
178 let (editable_region_offset, _) = prompt_inputs
179 .content
180 .match_indices(&old_editable_region)
181 .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
182 .unwrap();
183 let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
184 .matches('\n')
185 .count();
186
187 let diff = language::unified_diff_with_offsets(
188 &old_editable_region,
189 &new_editable_region,
190 editable_region_start_line as u32,
191 editable_region_start_line as u32,
192 );
193
194 let diff = indoc::formatdoc! {"
195 --- a/{path}
196 +++ b/{path}
197 {diff}",
198 path = example.spec.cursor_path.to_string_lossy(),
199 diff = diff,
200 };
201
202 Ok(diff)
203 }
204
205 fn format_edit_history(edit_history: &str) -> String {
206 // Strip comments ("garbage lines") from edit history
207 let lines = edit_history
208 .lines()
209 .filter(|&s| Self::is_udiff_content_line(s))
210 .collect::<Vec<_>>();
211
212 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
213 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
214 } else {
215 &lines
216 };
217
218 if history_lines.is_empty() {
219 return "(No edit history)".to_string();
220 }
221
222 history_lines.join("\n")
223 }
224
225 fn format_context(example: &Example) -> String {
226 let related_files = example
227 .prompt_inputs
228 .as_ref()
229 .and_then(|pi| pi.related_files.as_ref());
230
231 let Some(related_files) = related_files else {
232 return "(No context)".to_string();
233 };
234
235 if related_files.is_empty() {
236 return "(No context)".to_string();
237 }
238
239 let mut prompt = String::new();
240 for file in related_files {
241 let path_str = file.path.to_string_lossy();
242 writeln!(&mut prompt, "`````{path_str}").ok();
243 let mut prev_row = 0;
244 for excerpt in &file.excerpts {
245 if excerpt.row_range.start > prev_row {
246 prompt.push_str("…\n");
247 }
248 prompt.push_str(&excerpt.text);
249 prompt.push('\n');
250 prev_row = excerpt.row_range.end;
251 }
252 if prev_row < file.max_row {
253 prompt.push_str("…\n");
254 }
255 prompt.push_str("\n`````");
256 }
257
258 prompt
259 }
260
261 fn format_cursor_excerpt(
262 example: &Example,
263 editable_range: Range<usize>,
264 context_range: Range<usize>,
265 ) -> String {
266 let mut result = String::new();
267
268 let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
269
270 let path_str = example.spec.cursor_path.to_string_lossy();
271 result.push_str(&format!("`````{path_str}\n"));
272 result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
273 result.push_str(Self::EDITABLE_REGION_START);
274 result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
275 result.push_str(Self::USER_CURSOR_MARKER);
276 result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
277 result.push_str(Self::EDITABLE_REGION_END);
278 result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
279 result.push_str("\n`````");
280
281 result
282 }
283
284 fn extract_editable_region(text: &str) -> String {
285 let start = text
286 .find(Self::EDITABLE_REGION_START)
287 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
288 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
289
290 let region = &text[start..end];
291 let region = region.strip_suffix('\n').unwrap_or(region);
292
293 region.replace(Self::USER_CURSOR_MARKER, "")
294 }
295
296 fn is_udiff_content_line(s: &str) -> bool {
297 s.starts_with("-")
298 || s.starts_with("+")
299 || s.starts_with(" ")
300 || s.starts_with("---")
301 || s.starts_with("+++")
302 || s.starts_with("@@")
303 }
304}
305
/// Returns the contents of the last complete fenced code block in `text`.
///
/// A fence is any run of three or more backticks; the rest of the opening line is skipped
/// as the info string. A block is closed by the first newline followed by the same number
/// of backticks as its opening fence. The returned contents include the newline that
/// precedes the closing fence. Scanning stops at the first fence that is never closed.
/// When no complete block exists, the whole input is returned unchanged.
fn extract_last_codeblock(text: &str) -> String {
    let raw = text.as_bytes();
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    while let Some(found) = text[cursor..].find("```") {
        let fence_start = cursor + found;

        // Measure the whole backtick run so longer fences can wrap shorter ones.
        let fence_len = raw[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip the remainder of the opening line (the info string, if any).
        let after_run = fence_start + fence_len;
        let header_end = raw[after_run..]
            .iter()
            .position(|&byte| byte == b'\n')
            .map_or(raw.len(), |offset| after_run + offset);

        let Some(close_offset) = text[header_end..].find(&closing_fence) else {
            break;
        };

        // Body runs from just after the header newline through the newline that
        // introduces the closing fence.
        latest = Some(&text[header_end + 1..header_end + close_offset + 1]);
        cursor = header_end + close_offset + closing_fence.len();
    }

    latest.map_or_else(|| text.to_string(), str::to_string)
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks present, only the final one is returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
        "};
        let extracted = extract_last_codeblock(response);
        assert_eq!(extracted, "last block\n");
    }

    /// Shorter backtick runs inside a five-backtick block do not terminate it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
        "};
        let extracted = extract_last_codeblock(response);
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";
        assert_eq!(extracted, expected);
    }

    /// Backticks appearing mid-line inside a block are kept as content.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
        "};
        let extracted = extract_last_codeblock(response);
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";
        assert_eq!(extracted, expected);
    }

    /// Only the text between the region markers is returned, minus one trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let excerpt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
        "};
        let expected = indoc::indoc! {"
            one
            two three"};
        let region = TeacherPrompt::extract_editable_region(excerpt);
        assert_eq!(region, expected);
    }

    /// A three-backtick bibtex block nested in a five-backtick block is preserved whole.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
            `````
        "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
        "#};
        let extracted = extract_last_codeblock(response);
        assert_eq!(extracted, expected);
    }

    /// Text without any markers is returned as-is.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let excerpt = indoc::indoc! {"
            one
            two three"};
        let expected = indoc::indoc! {"
            one
            two three"};
        let region = TeacherPrompt::extract_editable_region(excerpt);
        assert_eq!(region, expected);
    }

    /// The user-cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let excerpt = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
        "};
        let expected = indoc::indoc! {"
            one
            two three"};
        let region = TeacherPrompt::extract_editable_region(excerpt);
        assert_eq!(region, expected);
    }
}