1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use anyhow::{Context as _, Result};
10use edit_prediction::{
11 EditPredictionStore,
12 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
13};
14use gpui::{AsyncApp, Entity};
15use std::fmt::Write as _;
16use std::sync::Arc;
17use zeta_prompt::format_zeta_prompt;
18
19pub async fn run_format_prompt(
20 example: &mut Example,
21 prompt_format: PromptFormat,
22 app_state: Arc<EpAppState>,
23 mut cx: AsyncApp,
24) -> Result<()> {
25 run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
26
27 let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
28
29 match prompt_format {
30 PromptFormat::Teacher => {
31 step_progress.set_substatus("formatting teacher prompt");
32 let prompt = TeacherPrompt::format_prompt(example);
33 example.prompt = Some(ExamplePrompt {
34 input: prompt,
35 expected_output: example
36 .spec
37 .expected_patches
38 .first()
39 .cloned()
40 .unwrap_or_default(),
41 format: prompt_format,
42 });
43 }
44 PromptFormat::Zeta2 => {
45 step_progress.set_substatus("loading project");
46 run_load_project(example, app_state, cx.clone()).await?;
47
48 step_progress.set_substatus("formatting zeta2 prompt");
49
50 let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
51 EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
52 })?;
53
54 let state = example.state.as_ref().context("state must be set")?;
55 let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
56 let project = state.project.clone();
57 let (_, input) =
58 ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
59 let events = ep_store
60 .edit_history_for_project(&project, cx)
61 .into_iter()
62 .map(|e| e.event)
63 .collect();
64 anyhow::Ok(zeta2_prompt_input(
65 &snapshot,
66 example
67 .context
68 .as_ref()
69 .context("context must be set")?
70 .files
71 .clone(),
72 events,
73 example.spec.cursor_path.clone(),
74 example
75 .buffer
76 .as_ref()
77 .context("buffer must be set")?
78 .cursor_offset,
79 ))
80 })?;
81 let prompt = format_zeta_prompt(&input);
82 let expected_output = zeta2_output_for_patch(
83 &input,
84 &example
85 .spec
86 .expected_patches
87 .first()
88 .context("expected patches is empty")?
89 .clone(),
90 )?;
91 example.prompt = Some(ExamplePrompt {
92 input: prompt,
93 expected_output,
94 format: prompt_format,
95 });
96 }
97 };
98 Ok(())
99}
100
101pub struct TeacherPrompt;
102
103impl TeacherPrompt {
104 const PROMPT: &str = include_str!("teacher.prompt.md");
105 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
106 pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
107 pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
108
109 /// Truncate edit history to this number of last lines
110 const MAX_HISTORY_LINES: usize = 128;
111
112 pub fn format_prompt(example: &Example) -> String {
113 let edit_history = Self::format_edit_history(&example.spec.edit_history);
114 let context = Self::format_context(example);
115 let cursor_excerpt = Self::format_cursor_excerpt(example);
116
117 let prompt = Self::PROMPT
118 .replace("{{context}}", &context)
119 .replace("{{edit_history}}", &edit_history)
120 .replace("{{cursor_excerpt}}", &cursor_excerpt);
121
122 prompt
123 }
124
125 pub fn parse(example: &Example, response: &str) -> Result<String> {
126 // Ideally, we should always be able to find cursor position in the retrieved context.
127 // In reality, sometimes we don't find it for these reasons:
128 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
129 // (can be fixed by getting cursor coordinates at the load_example stage)
130 // 2. Context retriever just didn't include cursor line.
131 //
132 // In that case, fallback to using `cursor_position` as excerpt.
133 let example_buffer = example
134 .buffer
135 .as_ref()
136 .context("`buffer` should be filled in in the context collection step")?;
137
138 // Extract updated (new) editable region from the model response.
139 // The model may include editable region markers in its output, so we need to strip them.
140 let new_editable_region = extract_last_codeblock(response);
141 let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
142
143 let old_editable_region =
144 example_buffer.content[example_buffer.editable_range.clone()].to_string();
145
146 // Normalize leading newlines: if old starts with newline but new doesn't,
147 // prepend newline to new to preserve whitespace structure.
148 // This handles the case where the model drops the leading blank line.
149 if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
150 new_editable_region.insert(0, '\n');
151 }
152
153 let editable_region_start_line = example_buffer.content
154 [..example_buffer.editable_range.start]
155 .matches('\n')
156 .count();
157
158 let diff = language::unified_diff_with_offsets(
159 &old_editable_region,
160 &new_editable_region,
161 editable_region_start_line as u32,
162 editable_region_start_line as u32,
163 );
164
165 let diff = indoc::formatdoc! {"
166 --- a/{path}
167 +++ b/{path}
168 {diff}",
169 path = example.spec.cursor_path.to_string_lossy(),
170 diff = diff,
171 };
172
173 Ok(diff)
174 }
175
176 fn format_edit_history(edit_history: &str) -> String {
177 // Strip comments ("garbage lines") from edit history
178 let lines = edit_history
179 .lines()
180 .filter(|&s| Self::is_udiff_content_line(s))
181 .collect::<Vec<_>>();
182
183 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
184 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
185 } else {
186 &lines
187 };
188
189 if history_lines.is_empty() {
190 return "(No edit history)".to_string();
191 }
192
193 history_lines.join("\n")
194 }
195
196 fn format_context(example: &Example) -> String {
197 let context = example
198 .context
199 .as_ref()
200 .expect("Missing context retriever step");
201
202 if context.files.is_empty() {
203 return "(No context)".to_string();
204 }
205
206 let mut prompt = String::new();
207 for file in context.files.as_ref() {
208 let path_str = file.path.to_string_lossy();
209 writeln!(&mut prompt, "`````{path_str}").ok();
210 let mut prev_row = 0;
211 for excerpt in &file.excerpts {
212 if excerpt.row_range.start > prev_row {
213 prompt.push_str("…\n");
214 }
215 prompt.push_str(&excerpt.text);
216 prompt.push('\n');
217 prev_row = excerpt.row_range.end;
218 }
219 if prev_row < file.max_row {
220 prompt.push_str("…\n");
221 }
222 prompt.push_str("\n`````");
223 }
224
225 prompt
226 }
227
228 fn format_cursor_excerpt(example: &Example) -> String {
229 let mut result = String::new();
230
231 let example_buffer = example.buffer.as_ref().unwrap();
232
233 let path_str = example.spec.cursor_path.to_string_lossy();
234 result.push_str(&format!("`````{path_str}\n"));
235 result.push_str(
236 &example_buffer.content
237 [example_buffer.context_range.start..example_buffer.editable_range.start],
238 );
239 result.push_str(Self::EDITABLE_REGION_START);
240 result.push_str(
241 &example_buffer.content
242 [example_buffer.editable_range.start..example_buffer.cursor_offset],
243 );
244 result.push_str(Self::USER_CURSOR_MARKER);
245 result.push_str(
246 &example_buffer.content
247 [example_buffer.cursor_offset..example_buffer.editable_range.end],
248 );
249 result.push_str(Self::EDITABLE_REGION_END);
250 result.push_str(
251 &example_buffer.content
252 [example_buffer.editable_range.end..example_buffer.context_range.end],
253 );
254 result.push_str("\n`````");
255
256 result
257 }
258
259 fn extract_editable_region(text: &str) -> String {
260 let start = text
261 .find(Self::EDITABLE_REGION_START)
262 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
263 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
264
265 let region = &text[start..end];
266 let region = region.strip_suffix('\n').unwrap_or(region);
267
268 region.replace(Self::USER_CURSOR_MARKER, "")
269 }
270
271 fn is_udiff_content_line(s: &str) -> bool {
272 s.starts_with("-")
273 || s.starts_with("+")
274 || s.starts_with(" ")
275 || s.starts_with("---")
276 || s.starts_with("+++")
277 || s.starts_with("@@")
278 }
279}
280
/// Returns the contents of the last complete fenced code block in `text`, or the whole
/// `text` when it contains no complete block.
///
/// An opening fence is a line whose first non-blank characters are three or more backticks;
/// anything after the backticks on that line (an info string) is ignored. Backtick runs in
/// the middle of a line are treated as inline text, not as fences. A block is closed by the
/// next line that starts with at least as many backticks as its opening fence, so shorter
/// fences nested inside a longer one are kept as content. The returned block includes the
/// newline that precedes the closing fence.
///
/// Scanning stops at the first opening fence that is never closed.
fn extract_last_codeblock(text: &str) -> String {
    const MIN_FENCE_LEN: usize = 3;

    let mut last_block = None;
    // Byte offset of the start of the line currently being inspected.
    let mut line_start = 0;

    while line_start < text.len() {
        let rest = &text[line_start..];
        let line_len = rest.find('\n').unwrap_or(rest.len());
        // Offset of the newline ending this line (or of the end of the text).
        let line_end = line_start + line_len;

        let fence_len = rest[..line_len]
            .trim_start_matches([' ', '\t'])
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        if fence_len < MIN_FENCE_LEN {
            line_start = line_end + 1;
            continue;
        }

        let closing_pattern = format!("\n{}", "`".repeat(fence_len));
        let end_pos = match text[line_end..].find(&closing_pattern) {
            Some(end_pos) => end_pos,
            // Unterminated fence: keep whatever complete block was found before it.
            None => break,
        };

        // Skip the newline that ends the opening line; keep the one before the closing fence.
        last_block = Some(&text[line_end + 1..line_end + end_pos + 1]);

        // Resume on the line after the closing fence, so that surplus backticks on the
        // closing line are not mistaken for a new opening fence.
        let after_fence = line_end + end_pos + closing_pattern.len();
        line_start = match text[after_fence..].find('\n') {
            Some(newline) => after_fence + newline + 1,
            None => text.len(),
        };
    }

    last_block.unwrap_or(text).to_string()
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in a response, only the final one is returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Shorter fences inside a five-backtick block are content, not terminators.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Backtick runs in the middle of a content line do not end the block.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Only the text between the region markers is kept, minus one trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let excerpt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};

        let region = TeacherPrompt::extract_editable_region(excerpt);

        assert_eq!(region, "one\ntwo three");
    }

    /// A complete nested block (here a bibtex snippet) is preserved verbatim.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
            title={Unique3D},
            }
            ```
            "#};

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Text without any markers is returned unchanged.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let excerpt = indoc::indoc! {"
            one
            two three"};

        let region = TeacherPrompt::extract_editable_region(excerpt);

        assert_eq!(region, "one\ntwo three");
    }

    /// The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let excerpt = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};

        let region = TeacherPrompt::extract_editable_region(excerpt);

        assert_eq!(region, "one\ntwo three");
    }
}