format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result};
  9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::ZetaVersion;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    args: &FormatPromptArgs,
 21    app_state: Arc<EpAppState>,
 22    example_progress: &ExampleProgress,
 23    cx: AsyncApp,
 24) -> Result<()> {
 25    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 26
 27    let step_progress = example_progress.start(Step::FormatPrompt);
 28
 29    let prompt_inputs = example
 30        .prompt_inputs
 31        .as_ref()
 32        .context("prompt_inputs must be set after context retrieval")?;
 33
 34    let language = app_state
 35        .languages
 36        .load_language_for_file_path(&example.spec.cursor_path)
 37        .await
 38        .ok();
 39    let snapshot_fut = cx.update(|cx| {
 40        Buffer::build_snapshot(
 41            prompt_inputs.content.as_str().into(),
 42            language,
 43            Some(app_state.languages.clone()),
 44            None,
 45            cx,
 46        )
 47    });
 48    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 49    let snapshot = cx.background_spawn(snapshot_fut).await;
 50
 51    match args.provider {
 52        PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
 53            step_progress.set_substatus("formatting teacher prompt");
 54
 55            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 56                cursor_point,
 57                &snapshot,
 58                edit_prediction::zeta2::max_editable_tokens(version),
 59                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 60            );
 61            let editable_range = editable_range.to_offset(&snapshot);
 62            let context_range = context_range.to_offset(&snapshot);
 63
 64            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 65            example.prompt = Some(ExamplePrompt {
 66                input: prompt,
 67                expected_output: example
 68                    .spec
 69                    .expected_patches
 70                    .first()
 71                    .cloned()
 72                    .unwrap_or_default(),
 73                provider: args.provider,
 74            });
 75        }
 76        PredictionProvider::Zeta2(version) => {
 77            step_progress.set_substatus("formatting zeta2 prompt");
 78
 79            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 80                cursor_point,
 81                &snapshot,
 82                edit_prediction::zeta2::max_editable_tokens(version),
 83                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 84            );
 85            let editable_range = editable_range.to_offset(&snapshot);
 86            let context_range = context_range.to_offset(&snapshot);
 87
 88            let context_start = context_range.start;
 89            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 90            let editable_range_in_excerpt =
 91                (editable_range.start - context_start)..(editable_range.end - context_start);
 92            let input = zeta_prompt::ZetaPromptInput {
 93                cursor_path: example.spec.cursor_path.clone(),
 94                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 95                editable_range_in_excerpt,
 96                cursor_offset_in_excerpt,
 97                events: prompt_inputs.edit_history.clone(),
 98                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
 99            };
100            let prompt = format_zeta_prompt(&input, version);
101            let expected_output = zeta2_output_for_patch(
102                &input,
103                &example
104                    .spec
105                    .expected_patches
106                    .first()
107                    .context("expected patches is empty")?
108                    .clone(),
109                version,
110            )?;
111            example.prompt = Some(ExamplePrompt {
112                input: prompt,
113                expected_output,
114                provider: args.provider,
115            });
116        }
117        _ => {
118            panic!("Cannot format prompt for {:?}", args.provider);
119        }
120    };
121    Ok(())
122}
123
124pub fn zeta2_output_for_patch(
125    input: &zeta_prompt::ZetaPromptInput,
126    patch: &str,
127    version: ZetaVersion,
128) -> Result<String> {
129    let mut old_editable_region =
130        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
131
132    if !old_editable_region.ends_with_newline() {
133        old_editable_region.push('\n');
134    }
135
136    let mut result = edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region)
137        .with_context(|| {
138            format!(
139                "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
140                patch, old_editable_region
141            )
142        })?;
143
144    if version == ZetaVersion::V0120GitMergeMarkers {
145        if !result.ends_with('\n') {
146            result.push('\n');
147        }
148        result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
149    }
150
151    Ok(result)
152}
153
154pub struct TeacherPrompt;
155
156impl TeacherPrompt {
157    const PROMPT: &str = include_str!("teacher.prompt.md");
158    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
159    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
160    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
161
162    /// Truncate edit history to this number of last lines
163    const MAX_HISTORY_LINES: usize = 128;
164
165    pub fn format_prompt(
166        example: &Example,
167        editable_range: Range<usize>,
168        context_range: Range<usize>,
169    ) -> String {
170        let edit_history = Self::format_edit_history(&example.spec.edit_history);
171        let context = Self::format_context(example);
172        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
173
174        let prompt = Self::PROMPT
175            .replace("{{context}}", &context)
176            .replace("{{edit_history}}", &edit_history)
177            .replace("{{cursor_excerpt}}", &cursor_excerpt);
178
179        prompt
180    }
181
182    pub fn parse(example: &Example, response: &str) -> Result<String> {
183        // Extract updated (new) editable region from the model response.
184        // The model may include editable region markers in its output, so we need to strip them.
185        let new_editable_region = extract_last_codeblock(response);
186        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
187        let old_editable_region = Self::extract_editable_region(
188            &example
189                .prompt
190                .as_ref()
191                .context("example prompt missing")?
192                .input,
193        );
194        let prompt_inputs = example
195            .prompt_inputs
196            .as_ref()
197            .context("example is missing prompt inputs")?;
198
199        // Normalize leading newlines: if old starts with newline but new doesn't,
200        // prepend newline to new to preserve whitespace structure.
201        // This handles the case where the model drops the leading blank line.
202        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
203            new_editable_region.insert(0, '\n');
204        }
205
206        let (editable_region_offset, _) = prompt_inputs
207            .content
208            .match_indices(&old_editable_region)
209            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
210            .unwrap();
211        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
212            .matches('\n')
213            .count();
214
215        let diff = language::unified_diff_with_offsets(
216            &old_editable_region,
217            &new_editable_region,
218            editable_region_start_line as u32,
219            editable_region_start_line as u32,
220        );
221
222        let diff = indoc::formatdoc! {"
223            --- a/{path}
224            +++ b/{path}
225            {diff}",
226            path = example.spec.cursor_path.to_string_lossy(),
227            diff = diff,
228        };
229
230        Ok(diff)
231    }
232
233    fn format_edit_history(edit_history: &str) -> String {
234        // Strip comments ("garbage lines") from edit history
235        let lines = edit_history
236            .lines()
237            .filter(|&s| Self::is_udiff_content_line(s))
238            .collect::<Vec<_>>();
239
240        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
241            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
242        } else {
243            &lines
244        };
245
246        if history_lines.is_empty() {
247            return "(No edit history)".to_string();
248        }
249
250        history_lines.join("\n")
251    }
252
253    fn format_context(example: &Example) -> String {
254        let related_files = example
255            .prompt_inputs
256            .as_ref()
257            .and_then(|pi| pi.related_files.as_ref());
258
259        let Some(related_files) = related_files else {
260            return "(No context)".to_string();
261        };
262
263        if related_files.is_empty() {
264            return "(No context)".to_string();
265        }
266
267        let mut prompt = String::new();
268        for file in related_files {
269            let path_str = file.path.to_string_lossy();
270            writeln!(&mut prompt, "`````{path_str}").ok();
271
272            let mut prev_row = 0;
273            for excerpt in &file.excerpts {
274                if excerpt.row_range.start > prev_row {
275                    prompt.push_str("\n");
276                }
277                prompt.push_str(&excerpt.text);
278                prompt.push('\n');
279                prev_row = excerpt.row_range.end;
280            }
281            if prev_row < file.max_row {
282                prompt.push_str("\n");
283            }
284            prompt.push_str("\n`````\n");
285        }
286
287        prompt
288    }
289
290    fn format_cursor_excerpt(
291        example: &Example,
292        editable_range: Range<usize>,
293        context_range: Range<usize>,
294    ) -> String {
295        let mut result = String::new();
296
297        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
298
299        let path_str = example.spec.cursor_path.to_string_lossy();
300        result.push_str(&format!("`````{path_str}\n"));
301        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
302        result.push_str(Self::EDITABLE_REGION_START);
303        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
304        result.push_str(Self::USER_CURSOR_MARKER);
305        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
306        result.push_str(Self::EDITABLE_REGION_END);
307        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
308        result.push_str("\n`````");
309
310        result
311    }
312
313    fn extract_editable_region(text: &str) -> String {
314        let start = text
315            .rfind(Self::EDITABLE_REGION_START)
316            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
317        let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
318
319        let region = &text[start..end];
320        let region = region.strip_suffix('\n').unwrap_or(region);
321
322        region.replace(Self::USER_CURSOR_MARKER, "")
323    }
324
325    fn is_udiff_content_line(s: &str) -> bool {
326        s.starts_with("-")
327            || s.starts_with("+")
328            || s.starts_with(" ")
329            || s.starts_with("---")
330            || s.starts_with("+++")
331            || s.starts_with("@@")
332    }
333}
334
/// Returns the body of the last fenced code block in `text`.
///
/// How a block is recognized:
/// - An opening fence is any run of three or more backticks. The rest of its line (the
///   info string) is skipped.
/// - The block closes at the first later line that starts with the same number of
///   backticks.
/// - Scanning resumes after each closing fence, so fences nested inside a longer fence
///   are treated as content.
///
/// The returned body includes its trailing newline. If no complete block is found, the
/// whole `text` is returned unchanged.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut cursor = 0;
    let mut found: Option<&str> = None;

    loop {
        let Some(relative_open) = text[cursor..].find("```") else {
            break;
        };
        let fence_start = cursor + relative_open;

        // Measure the whole backtick run so the closing fence can match its length.
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&b| b == b'`')
            .count();
        let fence_end = fence_start + fence_len;

        // Skip the info string (language tag, path, ...) up to the end of the line.
        let line_end = bytes[fence_end..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |offset| fence_end + offset);

        let closing = format!("\n{}", "`".repeat(fence_len));
        let Some(relative_close) = text[line_end..].find(&closing) else {
            break;
        };

        // The body runs from just after the opening line's newline up to and including
        // the newline that precedes the closing fence.
        let body_end = line_end + relative_close + 1;
        found = Some(&text[line_end + 1..body_end]);
        cursor = line_end + relative_close + closing.len();
    }

    found.map_or_else(|| text.to_string(), str::to_string)
}
366
// Unit tests for the pure-text helpers: `extract_last_codeblock` and
// `TeacherPrompt::extract_editable_region`.
#[cfg(test)]
mod tests {
    use super::*;

    // With several fenced blocks, only the body of the last one is returned. The info
    // string on its opening fence is skipped.
    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    // Shorter backtick runs inside a five-backtick fence are content, not fences.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    // Backticks in the middle of a line do not end the block.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    // Text outside the markers is dropped, and so is the trailing newline before the
    // end marker.
    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // A complete three-backtick block nested in a five-backtick block is kept verbatim.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#}
        );
    }

    // Without any markers the whole input is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let text = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }
}