format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    progress::{Progress, Step},
  7    retrieve_context::run_context_retrieval,
  8};
  9use anyhow::{Context as _, Result, ensure};
 10use edit_prediction::{
 11    EditPredictionStore,
 12    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
 13};
 14use gpui::{AsyncApp, Entity};
 15use std::sync::Arc;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    prompt_format: PromptFormat,
 21    app_state: Arc<EpAppState>,
 22    mut cx: AsyncApp,
 23) -> Result<()> {
 24    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 25
 26    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 27
 28    match prompt_format {
 29        PromptFormat::Teacher => {
 30            step_progress.set_substatus("formatting teacher prompt");
 31            let prompt = TeacherPrompt::format_prompt(example);
 32            example.prompt = Some(ExamplePrompt {
 33                input: prompt,
 34                expected_output: example
 35                    .spec
 36                    .expected_patches
 37                    .first()
 38                    .cloned()
 39                    .unwrap_or_default(),
 40                format: prompt_format,
 41            });
 42        }
 43        PromptFormat::Zeta2 => {
 44            step_progress.set_substatus("loading project");
 45            run_load_project(example, app_state, cx.clone()).await?;
 46
 47            step_progress.set_substatus("formatting zeta2 prompt");
 48
 49            let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
 50                EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
 51            })?;
 52
 53            let state = example.state.as_ref().context("state must be set")?;
 54            let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
 55            let project = state.project.clone();
 56            let (_, input) =
 57                ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
 58                    let events = ep_store
 59                        .edit_history_for_project(&project, cx)
 60                        .into_iter()
 61                        .map(|e| e.event)
 62                        .collect();
 63                    anyhow::Ok(zeta2_prompt_input(
 64                        &snapshot,
 65                        example
 66                            .context
 67                            .as_ref()
 68                            .context("context must be set")?
 69                            .files
 70                            .clone(),
 71                        events,
 72                        example.spec.cursor_path.clone(),
 73                        example
 74                            .buffer
 75                            .as_ref()
 76                            .context("buffer must be set")?
 77                            .cursor_offset,
 78                    ))
 79                })?;
 80            let prompt = format_zeta_prompt(&input);
 81            let expected_output = zeta2_output_for_patch(
 82                &input,
 83                &example
 84                    .spec
 85                    .expected_patches
 86                    .first()
 87                    .context("expected patches is empty")?
 88                    .clone(),
 89            )?;
 90            example.prompt = Some(ExamplePrompt {
 91                input: prompt,
 92                expected_output,
 93                format: prompt_format,
 94            });
 95        }
 96    };
 97    Ok(())
 98}
 99
100pub struct TeacherPrompt;
101
102impl TeacherPrompt {
103    const PROMPT: &str = include_str!("teacher.prompt.md");
104    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
105    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
106    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
107
108    /// Truncate edit history to this number of last lines
109    const MAX_HISTORY_LINES: usize = 128;
110
111    pub fn format_prompt(example: &Example) -> String {
112        let edit_history = Self::format_edit_history(&example.spec.edit_history);
113        let context = Self::format_context(example);
114        let editable_region = Self::format_editable_region(example);
115
116        let prompt = Self::PROMPT
117            .replace("{{context}}", &context)
118            .replace("{{edit_history}}", &edit_history)
119            .replace("{{editable_region}}", &editable_region);
120
121        prompt
122    }
123
124    pub fn parse(example: &Example, response: &str) -> Result<String> {
125        // Ideally, we should always be able to find cursor position in the retrieved context.
126        // In reality, sometimes we don't find it for these reasons:
127        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
128        //    (can be fixed by getting cursor coordinates at the load_example stage)
129        // 2. Context retriever just didn't include cursor line.
130        //
131        // In that case, fallback to using `cursor_position` as excerpt.
132        let example_buffer = example
133            .buffer
134            .as_ref()
135            .context("`buffer` should be filled in in the context collection step")?;
136        let cursor_file = &example_buffer.content;
137
138        // Extract updated (new) editable region from the model response.
139        // The model may include editable region markers in its output, so we need to strip them.
140        let new_editable_region = extract_last_codeblock(response);
141        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
142
143        let old_editable_region =
144            example_buffer.content[example_buffer.editable_range.clone()].to_string();
145
146        // Normalize leading newlines: if old starts with newline but new doesn't,
147        // prepend newline to new to preserve whitespace structure.
148        // This handles the case where the model drops the leading blank line.
149        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
150            new_editable_region.insert(0, '\n');
151        }
152
153        ensure!(
154            cursor_file.contains(&old_editable_region),
155            "Something's wrong: editable_region is not found in the cursor file"
156        );
157
158        // Apply editable region to a larger context and compute diff.
159        // This is needed to get a better context lines around the editable region
160        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
161        let diff = language::unified_diff(&cursor_file, &edited_file);
162
163        let diff = indoc::formatdoc! {"
164            --- a/{path}
165            +++ b/{path}
166            {diff}",
167            path = example.spec.cursor_path.to_string_lossy(),
168            diff = diff,
169        };
170
171        Ok(diff)
172    }
173
174    fn format_edit_history(edit_history: &str) -> String {
175        // Strip comments ("garbage lines") from edit history
176        let lines = edit_history
177            .lines()
178            .filter(|&s| Self::is_udiff_content_line(s))
179            .collect::<Vec<_>>();
180
181        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
182            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
183        } else {
184            &lines
185        };
186
187        if history_lines.is_empty() {
188            return "(No edit history)".to_string();
189        }
190
191        history_lines.join("\n")
192    }
193
194    fn format_context(example: &Example) -> String {
195        assert!(example.context.is_some(), "Missing context retriever step");
196
197        let mut prompt = String::new();
198        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
199
200        prompt
201    }
202
203    fn format_editable_region(example: &Example) -> String {
204        let mut result = String::new();
205
206        let example_buffer = example.buffer.as_ref().unwrap();
207
208        let path_str = example.spec.cursor_path.to_string_lossy();
209        result.push_str(&format!("`````path=\"{path_str}\"\n"));
210        result.push_str(
211            &example_buffer.content
212                [example_buffer.context_range.start..example_buffer.editable_range.start],
213        );
214        result.push_str(Self::EDITABLE_REGION_START);
215        result.push_str(
216            &example_buffer.content
217                [example_buffer.editable_range.start..example_buffer.cursor_offset],
218        );
219        result.push_str(Self::USER_CURSOR_MARKER);
220        result.push_str(
221            &example_buffer.content
222                [example_buffer.cursor_offset..example_buffer.editable_range.end],
223        );
224        result.push_str(Self::EDITABLE_REGION_END);
225        result.push_str(
226            &example_buffer.content
227                [example_buffer.editable_range.end..example_buffer.context_range.end],
228        );
229        result.push_str("\n`````");
230
231        result
232    }
233
234    fn extract_editable_region(text: &str) -> String {
235        let start = text
236            .find(Self::EDITABLE_REGION_START)
237            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
238        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
239
240        let region = &text[start..end];
241        let region = region.strip_suffix('\n').unwrap_or(region);
242
243        region.replace("<|user_cursor|>", "")
244    }
245
246    fn is_udiff_content_line(s: &str) -> bool {
247        s.starts_with("-")
248            || s.starts_with("+")
249            || s.starts_with(" ")
250            || s.starts_with("---")
251            || s.starts_with("+++")
252            || s.starts_with("@@")
253    }
254}
255
/// Returns the contents of the last complete fenced code block in `text`.
///
/// A fence is a run of three or more backticks that starts its line; only spaces or
/// tabs may precede it. Backtick runs in the middle of a line are treated as inline
/// text and never open a block. The rest of the opening line (the info string) is
/// skipped. The block ends at the first following line that begins with at least as
/// many backticks as the opening fence, so a longer outer fence may contain shorter
/// nested fences. The returned text includes the newline that ends the last content
/// line.
///
/// Scanning stops at the first opening fence that has no closing fence. If no complete
/// block is found, the whole of `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(offset) = text[search_start..].find("```") {
        let fence_start = search_start + offset;
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let fence_end = fence_start + fence_len;

        // Skip backtick runs that have other text before them on the same line.
        // Pairing such a run with a later real fence would yield a bogus block.
        let opens_line = bytes[..fence_start]
            .iter()
            .rev()
            .take_while(|&&byte| byte != b'\n')
            .all(|&byte| byte == b' ' || byte == b'\t');
        if !opens_line {
            search_start = fence_end;
            continue;
        }

        // The block content starts after the newline that ends the opening line.
        let newline = match text[fence_end..].find('\n') {
            Some(info_len) => fence_end + info_len,
            None => break,
        };

        // Search from the newline itself so an empty block, where the closing fence
        // directly follows the opening line, is still matched.
        let closing_pattern = format!("\n{}", "`".repeat(fence_len));
        match text[newline..].find(&closing_pattern) {
            Some(end_pos) => {
                last_block = Some(text[newline + 1..newline + end_pos + 1].to_string());
                search_start = newline + end_pos + closing_pattern.len();
            }
            None => break,
        }
    }

    last_block.unwrap_or_else(|| text.to_string())
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks, the last one wins and its info string is skipped.
    #[test]
    fn test_extract_last_code_block() {
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Shorter fences inside a longer outer fence are part of the content.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = concat!(
            "`````\n",
            "content with ``` inline\n",
            "and ```python nested\n",
            "more content\n",
            "`````\n",
        );

        assert_eq!(
            extract_last_codeblock(response),
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    /// Inline backticks inside a block do not end it.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = concat!(
            "`````\n",
            "here is some `code` with inline backticks\n",
            "and here```more```stuff\n",
            "`````\n",
        );

        assert_eq!(
            extract_last_codeblock(response),
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    /// Only the text between the markers is kept, minus one trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let response = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );

        let region = TeacherPrompt::extract_editable_region(response);
        assert_eq!(region, "one\ntwo three\n");
    }

    /// A three-backtick block nested in a five-backtick block is returned verbatim.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = concat!(
            "Looking at the edit history, I can see that a Citation section was just added.\n",
            "\n",
            "`````\n",
            "## Collaborations\n",
            "Our mission is to create a 4D generative model.\n",
            "\n",
            "## Citation\n",
            "\n",
            "If you found Unique3D helpful, please cite our report:\n",
            "```bibtex\n",
            "@misc{wu2024unique3d,\n",
            "      title={Unique3D},\n",
            "}\n",
            "```\n",
            "`````\n",
        );
        let expected = concat!(
            "## Collaborations\n",
            "Our mission is to create a 4D generative model.\n",
            "\n",
            "## Citation\n",
            "\n",
            "If you found Unique3D helpful, please cite our report:\n",
            "```bibtex\n",
            "@misc{wu2024unique3d,\n",
            "      title={Unique3D},\n",
            "}\n",
            "```\n",
        );

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Without markers the whole text is the region, minus one trailing newline.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let response = "one\ntwo three\n";

        let region = TeacherPrompt::extract_editable_region(response);
        assert_eq!(region, "one\ntwo three");
    }

    /// The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let response = concat!(
            "<|editable_region_start|>\n",
            "one\n",
            "<|user_cursor|>two three\n",
            "\n",
            "<|editable_region_end|>\n",
        );

        let region = TeacherPrompt::extract_editable_region(response);
        assert_eq!(region, "one\ntwo three\n");
    }
}