format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{Progress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result};
  9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::format_zeta_prompt;
 16
 17pub async fn run_format_prompt(
 18    example: &mut Example,
 19    args: &FormatPromptArgs,
 20    app_state: Arc<EpAppState>,
 21    cx: AsyncApp,
 22) -> Result<()> {
 23    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 24
 25    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 26
 27    let prompt_inputs = example
 28        .prompt_inputs
 29        .as_ref()
 30        .context("prompt_inputs must be set after context retrieval")?;
 31
 32    let language = app_state
 33        .languages
 34        .load_language_for_file_path(&example.spec.cursor_path)
 35        .await
 36        .ok();
 37    let snapshot_fut = cx.update(|cx| {
 38        Buffer::build_snapshot(
 39            prompt_inputs.content.as_str().into(),
 40            language,
 41            Some(app_state.languages.clone()),
 42            cx,
 43        )
 44    });
 45    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 46    let snapshot = cx.background_spawn(snapshot_fut).await;
 47
 48    match args.provider {
 49        PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
 50            step_progress.set_substatus("formatting teacher prompt");
 51
 52            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 53                cursor_point,
 54                &snapshot,
 55                edit_prediction::zeta2::max_editable_tokens(version),
 56                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 57            );
 58            let editable_range = editable_range.to_offset(&snapshot);
 59            let context_range = context_range.to_offset(&snapshot);
 60
 61            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 62            example.prompt = Some(ExamplePrompt {
 63                input: prompt,
 64                expected_output: example
 65                    .spec
 66                    .expected_patches
 67                    .first()
 68                    .cloned()
 69                    .unwrap_or_default(),
 70                provider: args.provider,
 71            });
 72        }
 73        PredictionProvider::Zeta2(version) => {
 74            step_progress.set_substatus("formatting zeta2 prompt");
 75
 76            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 77                cursor_point,
 78                &snapshot,
 79                edit_prediction::zeta2::max_editable_tokens(version),
 80                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 81            );
 82            let editable_range = editable_range.to_offset(&snapshot);
 83            let context_range = context_range.to_offset(&snapshot);
 84
 85            let context_start = context_range.start;
 86            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 87            let editable_range_in_excerpt =
 88                (editable_range.start - context_start)..(editable_range.end - context_start);
 89            let input = zeta_prompt::ZetaPromptInput {
 90                cursor_path: example.spec.cursor_path.clone(),
 91                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 92                editable_range_in_excerpt,
 93                cursor_offset_in_excerpt,
 94                events: prompt_inputs.edit_history.clone(),
 95                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
 96            };
 97            let prompt = format_zeta_prompt(&input, version);
 98            let expected_output = zeta2_output_for_patch(
 99                &input,
100                &example
101                    .spec
102                    .expected_patches
103                    .first()
104                    .context("expected patches is empty")?
105                    .clone(),
106            )?;
107            example.prompt = Some(ExamplePrompt {
108                input: prompt,
109                expected_output,
110                provider: args.provider,
111            });
112        }
113        _ => {
114            panic!("Cannot format prompt for {:?}", args.provider);
115        }
116    };
117    Ok(())
118}
119
120pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
121    let mut old_editable_region =
122        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
123
124    if !old_editable_region.ends_with_newline() {
125        old_editable_region.push('\n');
126    }
127
128    edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
129        format!(
130            "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
131            patch, old_editable_region
132        )
133    })
134}
135
136pub struct TeacherPrompt;
137
138impl TeacherPrompt {
139    const PROMPT: &str = include_str!("teacher.prompt.md");
140    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
141    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
142    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
143
144    /// Truncate edit history to this number of last lines
145    const MAX_HISTORY_LINES: usize = 128;
146
147    pub fn format_prompt(
148        example: &Example,
149        editable_range: Range<usize>,
150        context_range: Range<usize>,
151    ) -> String {
152        let edit_history = Self::format_edit_history(&example.spec.edit_history);
153        let context = Self::format_context(example);
154        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
155
156        let prompt = Self::PROMPT
157            .replace("{{context}}", &context)
158            .replace("{{edit_history}}", &edit_history)
159            .replace("{{cursor_excerpt}}", &cursor_excerpt);
160
161        prompt
162    }
163
164    pub fn parse(example: &Example, response: &str) -> Result<String> {
165        // Extract updated (new) editable region from the model response.
166        // The model may include editable region markers in its output, so we need to strip them.
167        let new_editable_region = extract_last_codeblock(response);
168        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
169        let old_editable_region = Self::extract_editable_region(
170            &example
171                .prompt
172                .as_ref()
173                .context("example prompt missing")?
174                .input,
175        );
176        let prompt_inputs = example
177            .prompt_inputs
178            .as_ref()
179            .context("example is missing prompt inputs")?;
180
181        // Normalize leading newlines: if old starts with newline but new doesn't,
182        // prepend newline to new to preserve whitespace structure.
183        // This handles the case where the model drops the leading blank line.
184        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
185            new_editable_region.insert(0, '\n');
186        }
187
188        let (editable_region_offset, _) = prompt_inputs
189            .content
190            .match_indices(&old_editable_region)
191            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
192            .unwrap();
193        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
194            .matches('\n')
195            .count();
196
197        let diff = language::unified_diff_with_offsets(
198            &old_editable_region,
199            &new_editable_region,
200            editable_region_start_line as u32,
201            editable_region_start_line as u32,
202        );
203
204        let diff = indoc::formatdoc! {"
205            --- a/{path}
206            +++ b/{path}
207            {diff}",
208            path = example.spec.cursor_path.to_string_lossy(),
209            diff = diff,
210        };
211
212        Ok(diff)
213    }
214
215    fn format_edit_history(edit_history: &str) -> String {
216        // Strip comments ("garbage lines") from edit history
217        let lines = edit_history
218            .lines()
219            .filter(|&s| Self::is_udiff_content_line(s))
220            .collect::<Vec<_>>();
221
222        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
223            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
224        } else {
225            &lines
226        };
227
228        if history_lines.is_empty() {
229            return "(No edit history)".to_string();
230        }
231
232        history_lines.join("\n")
233    }
234
235    fn format_context(example: &Example) -> String {
236        let related_files = example
237            .prompt_inputs
238            .as_ref()
239            .and_then(|pi| pi.related_files.as_ref());
240
241        let Some(related_files) = related_files else {
242            return "(No context)".to_string();
243        };
244
245        if related_files.is_empty() {
246            return "(No context)".to_string();
247        }
248
249        let mut prompt = String::new();
250        for file in related_files {
251            let path_str = file.path.to_string_lossy();
252            writeln!(&mut prompt, "`````{path_str}").ok();
253            let mut prev_row = 0;
254            for excerpt in &file.excerpts {
255                if excerpt.row_range.start > prev_row {
256                    prompt.push_str("\n");
257                }
258                prompt.push_str(&excerpt.text);
259                prompt.push('\n');
260                prev_row = excerpt.row_range.end;
261            }
262            if prev_row < file.max_row {
263                prompt.push_str("\n");
264            }
265            prompt.push_str("\n`````");
266        }
267
268        prompt
269    }
270
271    fn format_cursor_excerpt(
272        example: &Example,
273        editable_range: Range<usize>,
274        context_range: Range<usize>,
275    ) -> String {
276        let mut result = String::new();
277
278        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
279
280        let path_str = example.spec.cursor_path.to_string_lossy();
281        result.push_str(&format!("`````{path_str}\n"));
282        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
283        result.push_str(Self::EDITABLE_REGION_START);
284        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
285        result.push_str(Self::USER_CURSOR_MARKER);
286        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
287        result.push_str(Self::EDITABLE_REGION_END);
288        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
289        result.push_str("\n`````");
290
291        result
292    }
293
294    fn extract_editable_region(text: &str) -> String {
295        let start = text
296            .find(Self::EDITABLE_REGION_START)
297            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
298        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
299
300        let region = &text[start..end];
301        let region = region.strip_suffix('\n').unwrap_or(region);
302
303        region.replace(Self::USER_CURSOR_MARKER, "")
304    }
305
306    fn is_udiff_content_line(s: &str) -> bool {
307        s.starts_with("-")
308            || s.starts_with("+")
309            || s.starts_with(" ")
310            || s.starts_with("---")
311            || s.starts_with("+++")
312            || s.starts_with("@@")
313    }
314}
315
/// Returns the contents of the last complete fenced code block in `text`.
///
/// An opening fence is a run of three or more backticks at the start of a line (optionally
/// after spaces or tabs). Anything following the backticks on that line (language tag, path,
/// ...) is ignored. The block is closed by a line starting with the same number of backticks, so
/// an outer fence with more backticks can contain shorter, nested fences. Backtick runs in the
/// middle of a line (inline code) are not treated as fences.
///
/// The returned text includes the newline that precedes the closing fence. If no complete block
/// exists, the whole `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(relative_start) = text[search_start..].find("```") {
        let start = search_start + relative_start;
        let mut fence_end = start;
        while fence_end < bytes.len() && bytes[fence_end] == b'`' {
            fence_end += 1;
        }

        // A backtick run in the middle of a line is inline code, not a fence: skip past it.
        if !is_at_line_start(text, start) {
            search_start = fence_end;
            continue;
        }

        let closing_pattern = format!("\n{}", "`".repeat(fence_end - start));

        // Skip the info string (language tag, path, ...) after the opening fence.
        let mut line_end = fence_end;
        while line_end < bytes.len() && bytes[line_end] != b'\n' {
            line_end += 1;
        }

        let Some(end_pos) = text[line_end..].find(&closing_pattern) else {
            break;
        };
        // `line_end` is the newline ending the opening fence line. The body starts after it and
        // keeps the newline that precedes the closing fence.
        last_block = Some(text[line_end + 1..line_end + end_pos + 1].to_string());
        search_start = line_end + end_pos + closing_pattern.len();
    }

    last_block.unwrap_or_else(|| text.to_string())
}

/// Returns true when only spaces or tabs precede byte offset `index` on its line.
fn is_at_line_start(text: &str, index: usize) -> bool {
    let line_start = text[..index].rfind('\n').map_or(0, |newline| newline + 1);
    text[line_start..index]
        .bytes()
        .all(|b| b == b' ' || b == b'\t')
}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the last fenced code block extracted from `text` equals `expected`.
    #[track_caller]
    fn check_last_codeblock(text: &str, expected: &str) {
        let actual = extract_last_codeblock(text);
        assert_eq!(actual, expected);
    }

    /// Asserts that the editable region extracted from `text` equals `expected`.
    #[track_caller]
    fn check_editable_region(text: &str, expected: &str) {
        let actual = TeacherPrompt::extract_editable_region(text);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        check_last_codeblock(response, "last block\n");
    }

    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";
        check_last_codeblock(response, expected);
    }

    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";
        check_last_codeblock(response, expected);
    }

    #[test]
    fn test_extract_editable_region() {
        let prompt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected = indoc::indoc! {"
            one
            two three"};
        check_editable_region(prompt, expected);
    }

    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};
        check_last_codeblock(response, expected);
    }

    #[test]
    fn test_extract_editable_region_no_markers() {
        let plain = indoc::indoc! {"
            one
            two three"};
        let expected = indoc::indoc! {"
            one
            two three"};
        check_editable_region(plain, expected);
    }

    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let prompt = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let expected = indoc::indoc! {"
            one
            two three"};
        check_editable_region(prompt, expected);
    }
}