format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{ActualCursor, Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result, anyhow};
  9use edit_prediction::udiff;
 10use gpui::AsyncApp;
 11use similar::DiffableStr;
 12use std::sync::Arc;
 13use std::{fmt::Write as _, ops::Range};
 14use zeta_prompt::{
 15    ZetaFormat, excerpt_range_for_format, format_zeta_prompt, resolve_cursor_region,
 16};
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    args: &FormatPromptArgs,
 21    app_state: Arc<EpAppState>,
 22    example_progress: &ExampleProgress,
 23    cx: AsyncApp,
 24) -> Result<()> {
 25    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 26
 27    let step_progress = example_progress.start(Step::FormatPrompt);
 28
 29    let prompt_inputs = example
 30        .prompt_inputs
 31        .as_ref()
 32        .context("prompt_inputs must be set after context retrieval")?;
 33
 34    match args.provider {
 35        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 36            step_progress.set_substatus("formatting teacher prompt");
 37
 38            let zeta_format = ZetaFormat::default();
 39            let excerpt_ranges = prompt_inputs
 40                .excerpt_ranges
 41                .as_ref()
 42                .context("prompt_inputs must have excerpt_ranges")?;
 43            let (editable_range, context_range) =
 44                excerpt_range_for_format(zeta_format, excerpt_ranges);
 45
 46            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 47            example.prompt = Some(ExamplePrompt {
 48                input: prompt,
 49                expected_output: String::new(),
 50                rejected_output: None,
 51                prefill: None,
 52                provider: args.provider,
 53            });
 54        }
 55        PredictionProvider::Zeta2(zeta_format) => {
 56            step_progress.set_substatus("formatting zeta2 prompt");
 57
 58            let prompt = format_zeta_prompt(prompt_inputs, zeta_format);
 59            let prefill = zeta_prompt::get_prefill(prompt_inputs, zeta_format);
 60            let (expected_patch, expected_cursor_offset) = example
 61                .spec
 62                .expected_patches_with_cursor_positions()
 63                .into_iter()
 64                .next()
 65                .context("expected patches is empty")?;
 66            let expected_output = zeta2_output_for_patch(
 67                prompt_inputs,
 68                &expected_patch,
 69                expected_cursor_offset,
 70                zeta_format,
 71            )?;
 72            let rejected_output = example.spec.rejected_patch.as_ref().and_then(|patch| {
 73                zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok()
 74            });
 75
 76            example.prompt = Some(ExamplePrompt {
 77                input: prompt,
 78                expected_output,
 79                rejected_output,
 80                provider: args.provider,
 81                prefill: Some(prefill),
 82            });
 83        }
 84        _ => {
 85            panic!("Cannot format prompt for {:?}", args.provider);
 86        }
 87    };
 88    Ok(())
 89}
 90
 91pub fn zeta2_output_for_patch(
 92    input: &zeta_prompt::ZetaPromptInput,
 93    patch: &str,
 94    cursor_offset: Option<usize>,
 95    version: ZetaFormat,
 96) -> Result<String> {
 97    let (context, editable_range, _) = resolve_cursor_region(input, version);
 98    let mut old_editable_region = context[editable_range].to_string();
 99
100    if !old_editable_region.ends_with_newline() {
101        old_editable_region.push('\n');
102    }
103
104    let (mut result, first_hunk_offset) =
105        udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context(
106            || {
107                format!(
108                    "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
109                    patch, old_editable_region
110                )
111            },
112        )?;
113
114    if let Some(cursor_offset) = cursor_offset {
115        // The cursor_offset is relative to the start of the hunk's new text (context + additions).
116        // We need to add where the hunk context matched in the editable region to compute
117        // the actual cursor position in the result.
118        let hunk_start = first_hunk_offset.unwrap_or(0);
119        let offset = result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()));
120        result.insert_str(offset, zeta_prompt::CURSOR_MARKER);
121    }
122
123    match version {
124        ZetaFormat::V0120GitMergeMarkers
125        | ZetaFormat::V0131GitMergeMarkersPrefix
126        | ZetaFormat::V0211SeedCoder => {
127            if !result.ends_with('\n') {
128                result.push('\n');
129            }
130            result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
131        }
132        _ => (),
133    }
134
135    Ok(result)
136}
137
138pub struct TeacherPrompt;
139
140impl TeacherPrompt {
141    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
142    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
143    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
144    pub(crate) const NO_EDITS: &str = "NO_EDITS";
145
146    /// Truncate edit history to this number of last lines
147    const MAX_HISTORY_LINES: usize = 128;
148
149    pub fn format_prompt(
150        example: &Example,
151        editable_range: Range<usize>,
152        context_range: Range<usize>,
153    ) -> String {
154        let edit_history = Self::format_edit_history(&example.spec.edit_history);
155        let context = Self::format_context(example);
156        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
157
158        let prompt_template = crate::prompt_assets::get_prompt("teacher.md");
159        let prompt = prompt_template
160            .replace("{{context}}", &context)
161            .replace("{{edit_history}}", &edit_history)
162            .replace("{{cursor_excerpt}}", &cursor_excerpt);
163
164        prompt
165    }
166
167    pub fn parse(example: &Example, response: &str) -> Result<(String, Option<ActualCursor>)> {
168        // Check if the model indicated no edits are needed
169        let no_edits = (String::new(), None);
170        if let Some(last_codeblock) = extract_last_codeblock(&response) {
171            if last_codeblock.trim() == Self::NO_EDITS {
172                return Ok(no_edits);
173            }
174        }
175
176        if response.trim().ends_with(Self::NO_EDITS) {
177            return Ok(no_edits);
178        }
179
180        // Extract updated (new) editable region from the model response.
181        let new_editable_region = Self::extract_editable_region(&response)?;
182        let cursor_offset = new_editable_region.find(Self::USER_CURSOR_MARKER);
183        let mut new_editable_region = new_editable_region.replace(Self::USER_CURSOR_MARKER, "");
184        let old_editable_region = Self::extract_editable_region(
185            &example
186                .prompt
187                .as_ref()
188                .context("example prompt missing")?
189                .input,
190        )?
191        .replace(Self::USER_CURSOR_MARKER, "");
192
193        let prompt_inputs = example
194            .prompt_inputs
195            .as_ref()
196            .context("example is missing prompt inputs")?;
197
198        // Normalize leading newlines: if old starts with newline but new doesn't,
199        // prepend newline to new to preserve whitespace structure.
200        // This handles the case where the model drops the leading blank line.
201        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
202            new_editable_region.insert(0, '\n');
203        }
204
205        let excerpt = prompt_inputs.cursor_excerpt.as_ref();
206        let (editable_region_offset, _) = excerpt
207            .match_indices(&old_editable_region)
208            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset_in_excerpt))
209            .context("editable region not found in prompt content")?;
210        let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count();
211
212        // Use full context so cursor offset (relative to editable region start) aligns with diff content
213        let editable_region_lines = old_editable_region.lines().count() as u32;
214        let diff = language::unified_diff_with_context(
215            &old_editable_region,
216            &new_editable_region,
217            editable_region_start_line as u32,
218            editable_region_start_line as u32,
219            editable_region_lines,
220        );
221
222        let diff = indoc::formatdoc! {"
223            --- a/{path}
224            +++ b/{path}
225            {diff}",
226            path = example.spec.cursor_path.to_string_lossy(),
227            diff = diff,
228        };
229
230        let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
231            ActualCursor::from_editable_region(
232                &example.spec.cursor_path,
233                editable_region_cursor_offset,
234                &new_editable_region,
235                excerpt,
236                editable_region_offset,
237                editable_region_start_line,
238            )
239        });
240
241        Ok((diff, actual_cursor))
242    }
243
244    fn format_edit_history(edit_history: &str) -> String {
245        let lines: Vec<&str> = edit_history.lines().collect();
246
247        if lines.is_empty() {
248            return "(No edit history)".to_string();
249        }
250
251        if lines.len() > Self::MAX_HISTORY_LINES {
252            let truncated = lines[lines.len() - Self::MAX_HISTORY_LINES..].join("\n");
253            format!("{truncated}\n[...truncated...]")
254        } else {
255            lines.join("\n")
256        }
257    }
258
259    pub fn format_context(example: &Example) -> String {
260        let related_files = example.prompt_inputs.as_ref().map(|pi| &pi.related_files);
261
262        let Some(related_files) = related_files else {
263            return "(No context)".to_string();
264        };
265
266        if related_files.is_empty() {
267            return "(No context)".to_string();
268        }
269
270        let mut prompt = String::new();
271        for file in related_files {
272            let path_str = file.path.to_string_lossy();
273            writeln!(&mut prompt, "`````{path_str}").ok();
274
275            let mut prev_row = 0;
276            for excerpt in &file.excerpts {
277                if excerpt.row_range.start > prev_row {
278                    prompt.push_str("\n");
279                }
280                prompt.push_str(&excerpt.text);
281                prompt.push('\n');
282                prev_row = excerpt.row_range.end;
283            }
284            if prev_row < file.max_row {
285                prompt.push_str("\n");
286            }
287            prompt.push_str("\n`````\n");
288        }
289
290        prompt
291    }
292
293    fn format_cursor_excerpt(
294        example: &Example,
295        editable_range: Range<usize>,
296        context_range: Range<usize>,
297    ) -> String {
298        let mut result = String::new();
299
300        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
301        let excerpt = prompt_inputs.cursor_excerpt.as_ref();
302        let cursor_offset = prompt_inputs.cursor_offset_in_excerpt;
303
304        let path_str = example.spec.cursor_path.to_string_lossy();
305        result.push_str(&format!("`````{path_str}\n"));
306        result.push_str(&excerpt[context_range.start..editable_range.start]);
307        result.push_str(Self::EDITABLE_REGION_START);
308        result.push_str(&excerpt[editable_range.start..cursor_offset]);
309        result.push_str(Self::USER_CURSOR_MARKER);
310        result.push_str(&excerpt[cursor_offset..editable_range.end]);
311        result.push_str(Self::EDITABLE_REGION_END);
312        result.push_str(&excerpt[editable_range.end..context_range.end]);
313        result.push_str("\n`````");
314
315        result
316    }
317
318    pub fn extract_editable_region(text: &str) -> Result<String> {
319        let start = text
320            .rfind(Self::EDITABLE_REGION_START)
321            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
322        let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
323
324        if start >= end {
325            return Err(anyhow!("Invalid editable region markers"));
326        }
327
328        let region = &text[start..end];
329        Ok(region.strip_suffix('\n').unwrap_or(region).to_string())
330    }
331}
332
333/// Extract the cursor excerpt from an example.
334/// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs.
335pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option<String> {
336    // If we have the original prompt, extract the cursor excerpt from it
337    if let Some(prompt) = &example.prompt {
338        // Find "# 3. Current File" section and extract the content
339        if let Some(start) = prompt.input.find("# 3. Current File") {
340            let content_start = prompt.input[start..].find('`').map(|i| start + i)?;
341            let backtick_count = prompt.input[content_start..]
342                .chars()
343                .take_while(|&c| c == '`')
344                .count();
345            let content_start = content_start + backtick_count;
346
347            // Find the path line and skip it
348            let newline_pos = prompt.input[content_start..].find('\n')?;
349            let text_start = content_start + newline_pos + 1;
350
351            // Find the closing backticks
352            let closing_pattern = "`".repeat(backtick_count);
353            let text_end = prompt.input[text_start..].find(&closing_pattern)?;
354            let cursor_excerpt = &prompt.input[text_start..text_start + text_end];
355
356            let path_str = example.spec.cursor_path.to_string_lossy();
357            return Some(format!("`````{path_str}\n{cursor_excerpt}`````"));
358        }
359    }
360
361    // Fallback: construct from prompt_inputs if available
362    let prompt_inputs = example.prompt_inputs.as_ref()?;
363    let excerpt = prompt_inputs.cursor_excerpt.as_ref();
364    let cursor_offset = prompt_inputs.cursor_offset_in_excerpt;
365
366    // Simple fallback: just show content around cursor with markers
367    let path_str = example.spec.cursor_path.to_string_lossy();
368    let mut result = format!("`````{path_str}\n");
369    result.push_str(TeacherPrompt::EDITABLE_REGION_START);
370    result.push_str(&excerpt[..cursor_offset]);
371    result.push_str(TeacherPrompt::USER_CURSOR_MARKER);
372    result.push_str(&excerpt[cursor_offset..]);
373    result.push_str(TeacherPrompt::EDITABLE_REGION_END);
374    result.push_str("\n`````");
375
376    Some(result)
377}
378
/// Returns the body of the last fenced code block in `text`, terminated by a newline.
///
/// The closing fence is the last line consisting only of three or more backticks, ignoring
/// surrounding whitespace. The opening fence is the nearest earlier line that, after leading
/// whitespace, starts with exactly that many backticks. An info string such as a language or
/// path may follow. Matching on fence length lets a block fenced with more backticks contain
/// shorter fences.
///
/// Body lines are returned unmodified, including any indentation. Returns `Some("")` for a
/// block with no body lines, and `None` if no matching fence pair exists.
pub(crate) fn extract_last_codeblock(text: &str) -> Option<String> {
    let lines: Vec<&str> = text.lines().collect();

    // Search from the end for a closing fence (line containing only backticks, 3+).
    let (closing_idx, fence_len) = lines.iter().enumerate().rev().find_map(|(idx, line)| {
        let line = line.trim();
        (line.len() >= 3 && line.chars().all(|c| c == '`')).then_some((idx, line.len()))
    })?;

    // Search backwards for the matching opening fence. Leading whitespace is ignored, as for
    // the closing fence, so indented blocks pair correctly. A longer fence does not match.
    let fence = "`".repeat(fence_len);
    let opening_idx = lines[..closing_idx].iter().rposition(|line| {
        line.trim_start()
            .strip_prefix(fence.as_str())
            .is_some_and(|rest| !rest.starts_with('`'))
    })?;

    let body = &lines[opening_idx + 1..closing_idx];
    if body.is_empty() {
        Some(String::new())
    } else {
        // Preserve trailing newline to match previous behavior.
        Some(format!("{}\n", body.join("\n")))
    }
}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `text` contains a fenced code block and that the last one's body is
    /// exactly `expected`.
    fn assert_last_codeblock(text: &str, expected: &str) {
        let extracted = extract_last_codeblock(text).expect("expected a fenced code block");
        assert_eq!(extracted, expected);
    }

    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        assert_last_codeblock(text, "last block\n");
    }

    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        assert_last_codeblock(
            text,
            "content with ``` inline\nand ```python nested\nmore content\n",
        );
    }

    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        assert_last_codeblock(
            text,
            "here is some `code` with inline backticks\nand here```more```stuff\n",
        );
    }

    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let region = TeacherPrompt::extract_editable_region(text).unwrap();
        assert_eq!(region, "one\ntwo three");
    }

    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};
        assert_last_codeblock(text, expected);
    }

    #[test]
    fn test_extract_editable_region_no_markers() {
        // Without markers the whole input is the region.
        let text = "one\ntwo three";
        let region = TeacherPrompt::extract_editable_region(text).unwrap();
        assert_eq!(region, text);
    }

    #[test]
    fn test_parse_no_edits_response() {
        let response = indoc::indoc! {"
            The code is already complete. There is no clear next edit to make.

            `````
            NO_EDITS
            `````
        "};
        let codeblock = extract_last_codeblock(response).expect("expected a fenced code block");
        assert_eq!(codeblock.trim(), TeacherPrompt::NO_EDITS);
    }

    #[test]
    fn test_extract_codeblock_no_valid_block() {
        // Text with no code blocks should return None
        assert_eq!(
            extract_last_codeblock("Just some plain text without any code blocks"),
            None
        );

        // Unclosed code block should return None
        let unclosed = indoc::indoc! {"
            ```
            unclosed block
        "};
        assert_eq!(extract_last_codeblock(unclosed), None);

        // Analysis text with nested markdown but no proper outer block:
        // the inner block is the one that gets returned.
        let analysis = indoc::indoc! {"
            # Analysis
            Looking at this:
            ```
            some code
            ```
            But then more analysis without wrapping block
        "};
        assert_last_codeblock(analysis, "some code\n");
    }

    #[test]
    fn test_extract_codeblock_no_trailing_newline() {
        // Text ending without trailing newline after closing fence
        assert_last_codeblock("`````\ncontent here\n`````", "content here\n");
    }
}