format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{Progress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result};
  9use gpui::AsyncApp;
 10use similar::DiffableStr;
 11use std::fmt::Write as _;
 12use std::sync::Arc;
 13use zeta_prompt::format_zeta_prompt;
 14
 15pub async fn run_format_prompt(
 16    example: &mut Example,
 17    args: &FormatPromptArgs,
 18    app_state: Arc<EpAppState>,
 19    cx: AsyncApp,
 20) -> Result<()> {
 21    run_context_retrieval(example, app_state, cx).await?;
 22
 23    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 24
 25    let prompt_inputs = example
 26        .prompt_inputs
 27        .as_ref()
 28        .context("prompt_inputs must be set after context retrieval")?;
 29
 30    match args.provider {
 31        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
 32            step_progress.set_substatus("formatting teacher prompt");
 33            let prompt = TeacherPrompt::format_prompt(example);
 34            example.prompt = Some(ExamplePrompt {
 35                input: prompt,
 36                expected_output: example
 37                    .spec
 38                    .expected_patches
 39                    .first()
 40                    .cloned()
 41                    .unwrap_or_default(),
 42                provider: args.provider,
 43            });
 44        }
 45        PredictionProvider::Zeta2 => {
 46            step_progress.set_substatus("formatting zeta2 prompt");
 47
 48            let context_start = prompt_inputs.context_range.start;
 49            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 50            let editable_range_in_excerpt = (prompt_inputs.editable_range.start - context_start)
 51                ..(prompt_inputs.editable_range.end - context_start);
 52            let input = zeta_prompt::ZetaPromptInput {
 53                cursor_path: example.spec.cursor_path.clone(),
 54                cursor_excerpt: prompt_inputs.content[prompt_inputs.context_range.clone()]
 55                    .to_string()
 56                    .into(),
 57                editable_range_in_excerpt,
 58                cursor_offset_in_excerpt,
 59                events: prompt_inputs.edit_history.clone(),
 60                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
 61            };
 62            let prompt = format_zeta_prompt(&input, args.version);
 63            let expected_output = zeta2_output_for_patch(
 64                &input,
 65                &example
 66                    .spec
 67                    .expected_patches
 68                    .first()
 69                    .context("expected patches is empty")?
 70                    .clone(),
 71            )?;
 72            example.prompt = Some(ExamplePrompt {
 73                input: prompt,
 74                expected_output,
 75                provider: args.provider,
 76            });
 77        }
 78        _ => {
 79            panic!("Cannot format prompt for {:?}", args.provider);
 80        }
 81    };
 82    Ok(())
 83}
 84
 85pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
 86    let mut old_editable_region =
 87        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
 88
 89    if !old_editable_region.ends_with_newline() {
 90        old_editable_region.push('\n');
 91    }
 92
 93    edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
 94        format!(
 95            "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
 96            patch, old_editable_region
 97        )
 98    })
 99}
100
101pub struct TeacherPrompt;
102
103impl TeacherPrompt {
104    const PROMPT: &str = include_str!("teacher.prompt.md");
105    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
106    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
107    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
108
109    /// Truncate edit history to this number of last lines
110    const MAX_HISTORY_LINES: usize = 128;
111
112    pub fn format_prompt(example: &Example) -> String {
113        let edit_history = Self::format_edit_history(&example.spec.edit_history);
114        let context = Self::format_context(example);
115        let cursor_excerpt = Self::format_cursor_excerpt(example);
116
117        let prompt = Self::PROMPT
118            .replace("{{context}}", &context)
119            .replace("{{edit_history}}", &edit_history)
120            .replace("{{cursor_excerpt}}", &cursor_excerpt);
121
122        prompt
123    }
124
125    pub fn parse(example: &Example, response: &str) -> Result<String> {
126        // Ideally, we should always be able to find cursor position in the retrieved context.
127        // In reality, sometimes we don't find it for these reasons:
128        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
129        //    (can be fixed by getting cursor coordinates at the load_example stage)
130        // 2. Context retriever just didn't include cursor line.
131        //
132        // In that case, fallback to using `cursor_position` as excerpt.
133        let prompt_inputs = example
134            .prompt_inputs
135            .as_ref()
136            .context("`prompt_inputs` should be filled in in the context collection step")?;
137
138        // Extract updated (new) editable region from the model response.
139        // The model may include editable region markers in its output, so we need to strip them.
140        let new_editable_region = extract_last_codeblock(response);
141        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
142
143        let old_editable_region =
144            prompt_inputs.content[prompt_inputs.editable_range.clone()].to_string();
145
146        // Normalize leading newlines: if old starts with newline but new doesn't,
147        // prepend newline to new to preserve whitespace structure.
148        // This handles the case where the model drops the leading blank line.
149        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
150            new_editable_region.insert(0, '\n');
151        }
152
153        let editable_region_start_line = prompt_inputs.content
154            [..prompt_inputs.editable_range.start]
155            .matches('\n')
156            .count();
157
158        let diff = language::unified_diff_with_offsets(
159            &old_editable_region,
160            &new_editable_region,
161            editable_region_start_line as u32,
162            editable_region_start_line as u32,
163        );
164
165        let diff = indoc::formatdoc! {"
166            --- a/{path}
167            +++ b/{path}
168            {diff}",
169            path = example.spec.cursor_path.to_string_lossy(),
170            diff = diff,
171        };
172
173        Ok(diff)
174    }
175
176    fn format_edit_history(edit_history: &str) -> String {
177        // Strip comments ("garbage lines") from edit history
178        let lines = edit_history
179            .lines()
180            .filter(|&s| Self::is_udiff_content_line(s))
181            .collect::<Vec<_>>();
182
183        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
184            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
185        } else {
186            &lines
187        };
188
189        if history_lines.is_empty() {
190            return "(No edit history)".to_string();
191        }
192
193        history_lines.join("\n")
194    }
195
196    fn format_context(example: &Example) -> String {
197        let related_files = example
198            .prompt_inputs
199            .as_ref()
200            .and_then(|pi| pi.related_files.as_ref());
201
202        let Some(related_files) = related_files else {
203            return "(No context)".to_string();
204        };
205
206        if related_files.is_empty() {
207            return "(No context)".to_string();
208        }
209
210        let mut prompt = String::new();
211        for file in related_files {
212            let path_str = file.path.to_string_lossy();
213            writeln!(&mut prompt, "`````{path_str}").ok();
214            let mut prev_row = 0;
215            for excerpt in &file.excerpts {
216                if excerpt.row_range.start > prev_row {
217                    prompt.push_str("\n");
218                }
219                prompt.push_str(&excerpt.text);
220                prompt.push('\n');
221                prev_row = excerpt.row_range.end;
222            }
223            if prev_row < file.max_row {
224                prompt.push_str("\n");
225            }
226            prompt.push_str("\n`````");
227        }
228
229        prompt
230    }
231
232    fn format_cursor_excerpt(example: &Example) -> String {
233        let mut result = String::new();
234
235        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
236
237        let path_str = example.spec.cursor_path.to_string_lossy();
238        result.push_str(&format!("`````{path_str}\n"));
239        result.push_str(
240            &prompt_inputs.content
241                [prompt_inputs.context_range.start..prompt_inputs.editable_range.start],
242        );
243        result.push_str(Self::EDITABLE_REGION_START);
244        result.push_str(
245            &prompt_inputs.content[prompt_inputs.editable_range.start..prompt_inputs.cursor_offset],
246        );
247        result.push_str(Self::USER_CURSOR_MARKER);
248        result.push_str(
249            &prompt_inputs.content[prompt_inputs.cursor_offset..prompt_inputs.editable_range.end],
250        );
251        result.push_str(Self::EDITABLE_REGION_END);
252        result.push_str(
253            &prompt_inputs.content
254                [prompt_inputs.editable_range.end..prompt_inputs.context_range.end],
255        );
256        result.push_str("\n`````");
257
258        result
259    }
260
261    fn extract_editable_region(text: &str) -> String {
262        let start = text
263            .find(Self::EDITABLE_REGION_START)
264            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
265        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
266
267        let region = &text[start..end];
268        let region = region.strip_suffix('\n').unwrap_or(region);
269
270        region.replace(Self::USER_CURSOR_MARKER, "")
271    }
272
273    fn is_udiff_content_line(s: &str) -> bool {
274        s.starts_with("-")
275            || s.starts_with("+")
276            || s.starts_with(" ")
277            || s.starts_with("---")
278            || s.starts_with("+++")
279            || s.starts_with("@@")
280    }
281}
282
/// Returns the contents of the last complete fenced code block in `text`, or all of
/// `text` when it contains no complete block.
///
/// A block opens at a run of three or more backticks that begins its line (only spaces
/// or tabs may precede it); the rest of that line is skipped as the info string. The
/// block closes at the first following line that starts with at least as many backticks,
/// so shorter fences nested inside stay part of the contents. The returned text includes
/// the newline that precedes the closing fence.
///
/// Backtick runs in the middle of a line never open a block. Otherwise inline code such
/// as "use ```foo``` here" in the prose would pair up with the real fence that follows
/// it and the wrong span would be returned.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(offset) = text[search_start..].find("```") {
        let fence_start = search_start + offset;
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let fence_end = fence_start + fence_len;

        // Skip backtick runs that do not begin their line.
        let line_start = text[..fence_start].rfind('\n').map_or(0, |pos| pos + 1);
        let begins_line = text[line_start..fence_start]
            .bytes()
            .all(|byte| byte == b' ' || byte == b'\t');
        if !begins_line {
            search_start = fence_end;
            continue;
        }

        // Skip the info string: the contents start after the newline ending the fence line.
        let fence_line_end = text[fence_end..]
            .find('\n')
            .map_or(text.len(), |pos| fence_end + pos);

        let closing_pattern = format!("\n{}", "`".repeat(fence_len));
        let Some(end_pos) = text[fence_line_end..].find(&closing_pattern) else {
            // Unterminated block: keep whatever complete block was found before it.
            break;
        };

        let code_block = &text[fence_line_end + 1..fence_line_end + end_pos + 1];
        last_block = Some(code_block.to_string());
        search_start = fence_line_end + end_pos + closing_pattern.len();
    }

    last_block.unwrap_or_else(|| text.to_string())
}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in the response, only the final one is returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Three-backtick fences inside a five-backtick block belong to its contents.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Backticks in the middle of a content line do not end the block.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Only the text between the markers survives, minus one trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let marked_up = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected = indoc::indoc! {"
            one
            two three"};
        assert_eq!(TeacherPrompt::extract_editable_region(marked_up), expected);
    }

    /// A fenced `bibtex` block nested in the outer block is returned along with it.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Text without any marker is returned unchanged.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let unmarked = indoc::indoc! {"
            one
            two three"};
        let expected = indoc::indoc! {"
            one
            two three"};
        assert_eq!(TeacherPrompt::extract_editable_region(unmarked), expected);
    }

    /// The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let marked_up = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let expected = indoc::indoc! {"
            one
            two three"};
        assert_eq!(TeacherPrompt::extract_editable_region(marked_up), expected);
    }
}