format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{ActualCursor, Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result, anyhow};
  9use edit_prediction::{cursor_excerpt::compute_excerpt_ranges, udiff};
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::format_zeta_prompt;
 16use zeta_prompt::{ZetaFormat, excerpt_range_for_format};
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    args: &FormatPromptArgs,
 21    app_state: Arc<EpAppState>,
 22    example_progress: &ExampleProgress,
 23    cx: AsyncApp,
 24) -> Result<()> {
 25    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 26
 27    let step_progress = example_progress.start(Step::FormatPrompt);
 28
 29    let prompt_inputs = example
 30        .prompt_inputs
 31        .as_ref()
 32        .context("prompt_inputs must be set after context retrieval")?;
 33
 34    let language = app_state
 35        .languages
 36        .load_language_for_file_path(&example.spec.cursor_path)
 37        .await
 38        .ok();
 39    let snapshot_fut = cx.update(|cx| {
 40        Buffer::build_snapshot(
 41            prompt_inputs.content.as_str().into(),
 42            language,
 43            Some(app_state.languages.clone()),
 44            cx,
 45        )
 46    });
 47    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 48    let snapshot = cx.background_spawn(snapshot_fut).await;
 49
 50    let (_, _, excerpt_ranges) = compute_excerpt_ranges(cursor_point, &snapshot);
 51
 52    match args.provider {
 53        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 54            step_progress.set_substatus("formatting teacher prompt");
 55
 56            let zeta_format = ZetaFormat::default();
 57            let (editable_range, context_range) =
 58                excerpt_range_for_format(zeta_format, &excerpt_ranges);
 59
 60            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 61            example.prompt = Some(ExamplePrompt {
 62                input: prompt,
 63                expected_output: String::new(),
 64                rejected_output: None,
 65                prefill: None,
 66                provider: args.provider,
 67            });
 68        }
 69        PredictionProvider::Zeta2(zeta_format) => {
 70            step_progress.set_substatus("formatting zeta2 prompt");
 71
 72            let (editable_range, context_range) =
 73                excerpt_range_for_format(zeta_format, &excerpt_ranges);
 74
 75            let context_start = context_range.start;
 76            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 77            let editable_range_in_excerpt =
 78                (editable_range.start - context_start)..(editable_range.end - context_start);
 79            let input = zeta_prompt::ZetaPromptInput {
 80                cursor_path: example.spec.cursor_path.clone(),
 81                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 82                editable_range_in_excerpt,
 83                cursor_offset_in_excerpt,
 84                excerpt_start_row: prompt_inputs.excerpt_start_row,
 85                events: prompt_inputs.edit_history.clone(),
 86                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
 87                excerpt_ranges: Some(excerpt_ranges),
 88                preferred_model: None,
 89                in_open_source_repo: example
 90                    .spec
 91                    .captured_prompt_input
 92                    .as_ref()
 93                    .map_or(false, |input| input.in_open_source_repo),
 94                can_collect_data: false,
 95            };
 96            let prompt = format_zeta_prompt(&input, zeta_format);
 97            let prefill = zeta_prompt::get_prefill(&input, zeta_format);
 98            let (expected_patch, expected_cursor_offset) = example
 99                .spec
100                .expected_patches_with_cursor_positions()
101                .into_iter()
102                .next()
103                .context("expected patches is empty")?;
104            let expected_output = zeta2_output_for_patch(
105                &input,
106                &expected_patch,
107                expected_cursor_offset,
108                zeta_format,
109            )?;
110            let rejected_output =
111                example.spec.rejected_patch.as_ref().and_then(|patch| {
112                    zeta2_output_for_patch(&input, patch, None, zeta_format).ok()
113                });
114
115            example.prompt = Some(ExamplePrompt {
116                input: prompt,
117                expected_output,
118                rejected_output,
119                provider: args.provider,
120                prefill: Some(prefill),
121            });
122        }
123        _ => {
124            panic!("Cannot format prompt for {:?}", args.provider);
125        }
126    };
127    Ok(())
128}
129
130pub fn zeta2_output_for_patch(
131    input: &zeta_prompt::ZetaPromptInput,
132    patch: &str,
133    cursor_offset: Option<usize>,
134    version: ZetaFormat,
135) -> Result<String> {
136    let mut old_editable_region =
137        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
138
139    if !old_editable_region.ends_with_newline() {
140        old_editable_region.push('\n');
141    }
142
143    let (mut result, first_hunk_offset) =
144        udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context(
145            || {
146                format!(
147                    "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
148                    patch, old_editable_region
149                )
150            },
151        )?;
152
153    if let Some(cursor_offset) = cursor_offset {
154        // The cursor_offset is relative to the start of the hunk's new text (context + additions).
155        // We need to add where the hunk context matched in the editable region to compute
156        // the actual cursor position in the result.
157        let hunk_start = first_hunk_offset.unwrap_or(0);
158        let offset = result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()));
159        result.insert_str(offset, zeta_prompt::CURSOR_MARKER);
160    }
161
162    match version {
163        ZetaFormat::V0120GitMergeMarkers
164        | ZetaFormat::V0131GitMergeMarkersPrefix
165        | ZetaFormat::V0211SeedCoder => {
166            if !result.ends_with('\n') {
167                result.push('\n');
168            }
169            result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
170        }
171        _ => (),
172    }
173
174    Ok(result)
175}
176
177pub struct TeacherPrompt;
178
179impl TeacherPrompt {
180    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
181    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
182    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
183    pub(crate) const NO_EDITS: &str = "NO_EDITS";
184
185    /// Truncate edit history to this number of last lines
186    const MAX_HISTORY_LINES: usize = 128;
187
188    pub fn format_prompt(
189        example: &Example,
190        editable_range: Range<usize>,
191        context_range: Range<usize>,
192    ) -> String {
193        let edit_history = Self::format_edit_history(&example.spec.edit_history);
194        let context = Self::format_context(example);
195        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
196
197        let prompt_template = crate::prompt_assets::get_prompt("teacher.md");
198        let prompt = prompt_template
199            .replace("{{context}}", &context)
200            .replace("{{edit_history}}", &edit_history)
201            .replace("{{cursor_excerpt}}", &cursor_excerpt);
202
203        prompt
204    }
205
206    pub fn parse(example: &Example, response: &str) -> Result<(String, Option<ActualCursor>)> {
207        // Check if the model indicated no edits are needed
208        if let Some(last_codeblock) = extract_last_codeblock(&response) {
209            if last_codeblock.trim() == Self::NO_EDITS {
210                return Ok((String::new(), None));
211            }
212        }
213
214        // Extract updated (new) editable region from the model response.
215        let new_editable_region = Self::extract_editable_region(&response)?;
216        let cursor_offset = new_editable_region.find(Self::USER_CURSOR_MARKER);
217        let mut new_editable_region = new_editable_region.replace(Self::USER_CURSOR_MARKER, "");
218        let old_editable_region = Self::extract_editable_region(
219            &example
220                .prompt
221                .as_ref()
222                .context("example prompt missing")?
223                .input,
224        )?
225        .replace(Self::USER_CURSOR_MARKER, "");
226
227        let prompt_inputs = example
228            .prompt_inputs
229            .as_ref()
230            .context("example is missing prompt inputs")?;
231
232        // Normalize leading newlines: if old starts with newline but new doesn't,
233        // prepend newline to new to preserve whitespace structure.
234        // This handles the case where the model drops the leading blank line.
235        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
236            new_editable_region.insert(0, '\n');
237        }
238
239        let (editable_region_offset, _) = prompt_inputs
240            .content
241            .match_indices(&old_editable_region)
242            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
243            .context("editable region not found in prompt content")?;
244        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
245            .matches('\n')
246            .count();
247
248        // Use full context so cursor offset (relative to editable region start) aligns with diff content
249        let editable_region_lines = old_editable_region.lines().count() as u32;
250        let diff = language::unified_diff_with_context(
251            &old_editable_region,
252            &new_editable_region,
253            editable_region_start_line as u32,
254            editable_region_start_line as u32,
255            editable_region_lines,
256        );
257
258        let diff = indoc::formatdoc! {"
259            --- a/{path}
260            +++ b/{path}
261            {diff}",
262            path = example.spec.cursor_path.to_string_lossy(),
263            diff = diff,
264        };
265
266        let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
267            ActualCursor::from_editable_region(
268                &example.spec.cursor_path,
269                editable_region_cursor_offset,
270                &new_editable_region,
271                &prompt_inputs.content,
272                editable_region_offset,
273                editable_region_start_line,
274            )
275        });
276
277        Ok((diff, actual_cursor))
278    }
279
280    fn format_edit_history(edit_history: &str) -> String {
281        let lines: Vec<&str> = edit_history.lines().collect();
282
283        if lines.is_empty() {
284            return "(No edit history)".to_string();
285        }
286
287        if lines.len() > Self::MAX_HISTORY_LINES {
288            let truncated = lines[lines.len() - Self::MAX_HISTORY_LINES..].join("\n");
289            format!("{truncated}\n[...truncated...]")
290        } else {
291            lines.join("\n")
292        }
293    }
294
295    pub fn format_context(example: &Example) -> String {
296        let related_files = example
297            .prompt_inputs
298            .as_ref()
299            .and_then(|pi| pi.related_files.as_ref());
300
301        let Some(related_files) = related_files else {
302            return "(No context)".to_string();
303        };
304
305        if related_files.is_empty() {
306            return "(No context)".to_string();
307        }
308
309        let mut prompt = String::new();
310        for file in related_files {
311            let path_str = file.path.to_string_lossy();
312            writeln!(&mut prompt, "`````{path_str}").ok();
313
314            let mut prev_row = 0;
315            for excerpt in &file.excerpts {
316                if excerpt.row_range.start > prev_row {
317                    prompt.push_str("\n");
318                }
319                prompt.push_str(&excerpt.text);
320                prompt.push('\n');
321                prev_row = excerpt.row_range.end;
322            }
323            if prev_row < file.max_row {
324                prompt.push_str("\n");
325            }
326            prompt.push_str("\n`````\n");
327        }
328
329        prompt
330    }
331
332    fn format_cursor_excerpt(
333        example: &Example,
334        editable_range: Range<usize>,
335        context_range: Range<usize>,
336    ) -> String {
337        let mut result = String::new();
338
339        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
340
341        let path_str = example.spec.cursor_path.to_string_lossy();
342        result.push_str(&format!("`````{path_str}\n"));
343        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
344        result.push_str(Self::EDITABLE_REGION_START);
345        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
346        result.push_str(Self::USER_CURSOR_MARKER);
347        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
348        result.push_str(Self::EDITABLE_REGION_END);
349        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
350        result.push_str("\n`````");
351
352        result
353    }
354
355    pub fn extract_editable_region(text: &str) -> Result<String> {
356        let start = text
357            .rfind(Self::EDITABLE_REGION_START)
358            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
359        let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
360
361        if start >= end {
362            return Err(anyhow!("Invalid editable region markers"));
363        }
364
365        let region = &text[start..end];
366        Ok(region.strip_suffix('\n').unwrap_or(region).to_string())
367    }
368}
369
370/// Extract the cursor excerpt from an example.
371/// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs.
372pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option<String> {
373    // If we have the original prompt, extract the cursor excerpt from it
374    if let Some(prompt) = &example.prompt {
375        // Find "# 3. Current File" section and extract the content
376        if let Some(start) = prompt.input.find("# 3. Current File") {
377            let content_start = prompt.input[start..].find('`').map(|i| start + i)?;
378            let backtick_count = prompt.input[content_start..]
379                .chars()
380                .take_while(|&c| c == '`')
381                .count();
382            let content_start = content_start + backtick_count;
383
384            // Find the path line and skip it
385            let newline_pos = prompt.input[content_start..].find('\n')?;
386            let text_start = content_start + newline_pos + 1;
387
388            // Find the closing backticks
389            let closing_pattern = "`".repeat(backtick_count);
390            let text_end = prompt.input[text_start..].find(&closing_pattern)?;
391            let cursor_excerpt = &prompt.input[text_start..text_start + text_end];
392
393            let path_str = example.spec.cursor_path.to_string_lossy();
394            return Some(format!("`````{path_str}\n{cursor_excerpt}`````"));
395        }
396    }
397
398    // Fallback: construct from prompt_inputs if available
399    let prompt_inputs = example.prompt_inputs.as_ref()?;
400    let content = &prompt_inputs.content;
401    let cursor_offset = prompt_inputs.cursor_offset;
402
403    // Simple fallback: just show content around cursor with markers
404    let path_str = example.spec.cursor_path.to_string_lossy();
405    let mut result = format!("`````{path_str}\n");
406    result.push_str(TeacherPrompt::EDITABLE_REGION_START);
407    result.push_str(&content[..cursor_offset]);
408    result.push_str(TeacherPrompt::USER_CURSOR_MARKER);
409    result.push_str(&content[cursor_offset..]);
410    result.push_str(TeacherPrompt::EDITABLE_REGION_END);
411    result.push_str("\n`````");
412
413    Some(result)
414}
415
/// Returns the contents of the last fenced code block in `text`.
///
/// The closing fence is the last line that, once trimmed, consists only of three or
/// more backticks. The opening fence is the nearest earlier line that starts with
/// exactly that many backticks, possibly followed by an info string. Shorter inner
/// fences therefore stay part of the content.
///
/// Non-empty content is returned with a trailing `'\n'`; an empty block yields `""`.
/// Returns `None` if no closing fence or matching opening fence exists.
pub(crate) fn extract_last_codeblock(text: &str) -> Option<String> {
    let lines: Vec<&str> = text.lines().collect();

    let (closing_idx, fence_len) = lines.iter().enumerate().rev().find_map(|(idx, line)| {
        let trimmed = line.trim();
        let is_fence = trimmed.len() >= 3 && trimmed.chars().all(|c| c == '`');
        is_fence.then(|| (idx, trimmed.len()))
    })?;

    let fence = "`".repeat(fence_len);

    // Exactly `fence_len` backticks: the character after the prefix must not be
    // another backtick.
    let opening_idx = lines[..closing_idx].iter().rposition(|line| {
        line.strip_prefix(fence.as_str())
            .is_some_and(|rest| !rest.starts_with('`'))
    })?;

    let body = &lines[opening_idx + 1..closing_idx];
    if body.is_empty() {
        Some(String::new())
    } else {
        Some(format!("{}\n", body.join("\n")))
    }
}
460
#[cfg(test)]
mod tests {
    // Unit tests for fenced-code-block extraction (`extract_last_codeblock`),
    // editable-region extraction (`TeacherPrompt::extract_editable_region`), and
    // NO_EDITS detection. The `indoc!` literals are whitespace-sensitive; keep their
    // indentation intact when editing.
    use super::*;

    // The last of several blocks wins, and an info string after the fence is allowed.
    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text).unwrap();
        assert_eq!(last_block, "last block\n");
    }

    // Shorter fences inside a longer-fenced block are content, not delimiters.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let last_block = extract_last_codeblock(text).unwrap();
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    // Backticks that do not form a whole fence line are ignored.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let last_block = extract_last_codeblock(text).unwrap();
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    // Text between the markers is returned, with one trailing newline stripped.
    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text).unwrap();
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // A realistic response: a markdown block containing a nested ```bibtex fence.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let last_block = extract_last_codeblock(text).unwrap();
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#}
        );
    }

    // Without any markers, the whole text is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text).unwrap();
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // A NO_EDITS sentinel in the last block is recoverable after trimming.
    #[test]
    fn test_parse_no_edits_response() {
        let response = indoc::indoc! {"
            The code is already complete. There is no clear next edit to make.

            `````
            NO_EDITS
            `````
        "};
        let codeblock = extract_last_codeblock(response).unwrap();
        assert_eq!(codeblock.trim(), TeacherPrompt::NO_EDITS);
    }

    // Missing or unclosed fences yield None; trailing prose after a block is fine.
    #[test]
    fn test_extract_codeblock_no_valid_block() {
        // Text with no code blocks should return None
        let text = "Just some plain text without any code blocks";
        assert!(extract_last_codeblock(text).is_none());

        // Unclosed code block should return None
        let text = indoc::indoc! {"
            ```
            unclosed block
        "};
        assert!(extract_last_codeblock(text).is_none());

        // Analysis text with nested markdown but no proper outer block
        let text = indoc::indoc! {"
            # Analysis
            Looking at this:
            ```
            some code
            ```
            But then more analysis without wrapping block
        "};
        // This should find the inner block
        let result = extract_last_codeblock(text).unwrap();
        assert_eq!(result, "some code\n");
    }

    // The closing fence may be the very last line with no trailing newline.
    #[test]
    fn test_extract_codeblock_no_trailing_newline() {
        // Text ending without trailing newline after closing fence
        let text = "`````\ncontent here\n`````";
        let result = extract_last_codeblock(text).unwrap();
        assert_eq!(result, "content here\n");
    }
}