format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{ActualCursor, Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result, anyhow};
  9use edit_prediction::{cursor_excerpt::editable_and_context_ranges_for_cursor_position, udiff};
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::ZetaFormat;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    args: &FormatPromptArgs,
 21    app_state: Arc<EpAppState>,
 22    example_progress: &ExampleProgress,
 23    cx: AsyncApp,
 24) -> Result<()> {
 25    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 26
 27    let step_progress = example_progress.start(Step::FormatPrompt);
 28
 29    let prompt_inputs = example
 30        .prompt_inputs
 31        .as_ref()
 32        .context("prompt_inputs must be set after context retrieval")?;
 33
 34    let language = app_state
 35        .languages
 36        .load_language_for_file_path(&example.spec.cursor_path)
 37        .await
 38        .ok();
 39    let snapshot_fut = cx.update(|cx| {
 40        Buffer::build_snapshot(
 41            prompt_inputs.content.as_str().into(),
 42            language,
 43            Some(app_state.languages.clone()),
 44            cx,
 45        )
 46    });
 47    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 48    let snapshot = cx.background_spawn(snapshot_fut).await;
 49
 50    match args.provider {
 51        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 52            step_progress.set_substatus("formatting teacher prompt");
 53
 54            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 55                cursor_point,
 56                &snapshot,
 57                edit_prediction::zeta2::max_editable_tokens(ZetaFormat::default()),
 58                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 59            );
 60            let editable_range = editable_range.to_offset(&snapshot);
 61            let context_range = context_range.to_offset(&snapshot);
 62
 63            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 64            example.prompt = Some(ExamplePrompt {
 65                input: prompt,
 66                expected_output: String::new(),
 67                rejected_output: None,
 68                provider: args.provider,
 69            });
 70        }
 71        PredictionProvider::Zeta2(version) => {
 72            step_progress.set_substatus("formatting zeta2 prompt");
 73
 74            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 75                cursor_point,
 76                &snapshot,
 77                edit_prediction::zeta2::max_editable_tokens(version),
 78                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 79            );
 80            let editable_range = editable_range.to_offset(&snapshot);
 81            let context_range = context_range.to_offset(&snapshot);
 82
 83            let context_start = context_range.start;
 84            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 85            let editable_range_in_excerpt =
 86                (editable_range.start - context_start)..(editable_range.end - context_start);
 87            let input = zeta_prompt::ZetaPromptInput {
 88                cursor_path: example.spec.cursor_path.clone(),
 89                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 90                editable_range_in_excerpt,
 91                cursor_offset_in_excerpt,
 92                excerpt_start_row: prompt_inputs.excerpt_start_row,
 93                events: prompt_inputs.edit_history.clone(),
 94                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
 95            };
 96            let prompt = format_zeta_prompt(&input, version);
 97            let (expected_patch, expected_cursor_offset) = example
 98                .spec
 99                .expected_patches_with_cursor_positions()
100                .into_iter()
101                .next()
102                .context("expected patches is empty")?;
103            let expected_output =
104                zeta2_output_for_patch(&input, &expected_patch, expected_cursor_offset, version)?;
105            let rejected_output = example
106                .spec
107                .rejected_patch
108                .as_ref()
109                .and_then(|patch| zeta2_output_for_patch(&input, patch, None, version).ok());
110
111            example.prompt = Some(ExamplePrompt {
112                input: prompt,
113                expected_output,
114                rejected_output,
115                provider: args.provider,
116            });
117        }
118        _ => {
119            panic!("Cannot format prompt for {:?}", args.provider);
120        }
121    };
122    Ok(())
123}
124
125pub fn zeta2_output_for_patch(
126    input: &zeta_prompt::ZetaPromptInput,
127    patch: &str,
128    cursor_offset: Option<usize>,
129    version: ZetaFormat,
130) -> Result<String> {
131    let mut old_editable_region =
132        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
133
134    if !old_editable_region.ends_with_newline() {
135        old_editable_region.push('\n');
136    }
137
138    let (mut result, first_hunk_offset) =
139        udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context(
140            || {
141                format!(
142                    "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
143                    patch, old_editable_region
144                )
145            },
146        )?;
147
148    if let Some(cursor_offset) = cursor_offset {
149        // The cursor_offset is relative to the start of the hunk's new text (context + additions).
150        // We need to add where the hunk context matched in the editable region to compute
151        // the actual cursor position in the result.
152        let hunk_start = first_hunk_offset.unwrap_or(0);
153        let offset = (hunk_start + cursor_offset).min(result.len());
154        result.insert_str(offset, zeta_prompt::CURSOR_MARKER);
155    }
156
157    match version {
158        ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => {
159            if !result.ends_with('\n') {
160                result.push('\n');
161            }
162            result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
163        }
164        _ => (),
165    }
166
167    Ok(result)
168}
169
170pub struct TeacherPrompt;
171
172impl TeacherPrompt {
173    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
174    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
175    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
176    pub(crate) const NO_EDITS: &str = "NO_EDITS";
177
178    /// Truncate edit history to this number of last lines
179    const MAX_HISTORY_LINES: usize = 128;
180
181    pub fn format_prompt(
182        example: &Example,
183        editable_range: Range<usize>,
184        context_range: Range<usize>,
185    ) -> String {
186        let edit_history = Self::format_edit_history(&example.spec.edit_history);
187        let context = Self::format_context(example);
188        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
189
190        let prompt_template = crate::prompt_assets::get_prompt("teacher.md");
191        let prompt = prompt_template
192            .replace("{{context}}", &context)
193            .replace("{{edit_history}}", &edit_history)
194            .replace("{{cursor_excerpt}}", &cursor_excerpt);
195
196        prompt
197    }
198
199    pub fn parse(example: &Example, response: &str) -> Result<(String, Option<ActualCursor>)> {
200        // Extract updated (new) editable region from the model response.
201        // The model may include editable region markers in its output, so we need to strip them.
202        let new_editable_region = extract_last_codeblock(response);
203
204        // Check if the model indicated no edits are needed
205        if new_editable_region.trim() == Self::NO_EDITS {
206            return Ok((String::new(), None));
207        }
208
209        let new_editable_region = Self::extract_editable_region(&new_editable_region)?;
210        let cursor_offset = new_editable_region.find(Self::USER_CURSOR_MARKER);
211        let mut new_editable_region = new_editable_region.replace(Self::USER_CURSOR_MARKER, "");
212        let old_editable_region = Self::extract_editable_region(
213            &example
214                .prompt
215                .as_ref()
216                .context("example prompt missing")?
217                .input,
218        )?
219        .replace(Self::USER_CURSOR_MARKER, "");
220
221        let prompt_inputs = example
222            .prompt_inputs
223            .as_ref()
224            .context("example is missing prompt inputs")?;
225
226        // Normalize leading newlines: if old starts with newline but new doesn't,
227        // prepend newline to new to preserve whitespace structure.
228        // This handles the case where the model drops the leading blank line.
229        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
230            new_editable_region.insert(0, '\n');
231        }
232
233        let (editable_region_offset, _) = prompt_inputs
234            .content
235            .match_indices(&old_editable_region)
236            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
237            .context("editable region not found in prompt content")?;
238        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
239            .matches('\n')
240            .count();
241
242        // Use full context so cursor offset (relative to editable region start) aligns with diff content
243        let editable_region_lines = old_editable_region.lines().count() as u32;
244        let diff = language::unified_diff_with_context(
245            &old_editable_region,
246            &new_editable_region,
247            editable_region_start_line as u32,
248            editable_region_start_line as u32,
249            editable_region_lines,
250        );
251
252        let diff = indoc::formatdoc! {"
253            --- a/{path}
254            +++ b/{path}
255            {diff}",
256            path = example.spec.cursor_path.to_string_lossy(),
257            diff = diff,
258        };
259
260        let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
261            ActualCursor::from_editable_region(
262                &example.spec.cursor_path,
263                editable_region_cursor_offset,
264                &new_editable_region,
265                &prompt_inputs.content,
266                editable_region_offset,
267                editable_region_start_line,
268            )
269        });
270
271        Ok((diff, actual_cursor))
272    }
273
274    fn format_edit_history(edit_history: &str) -> String {
275        // Strip comments ("garbage lines") from edit history
276        let lines = edit_history
277            .lines()
278            .filter(|&s| Self::is_udiff_content_line(s))
279            .collect::<Vec<_>>();
280
281        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
282            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
283        } else {
284            &lines
285        };
286
287        if history_lines.is_empty() {
288            return "(No edit history)".to_string();
289        }
290
291        history_lines.join("\n")
292    }
293
294    pub fn format_context(example: &Example) -> String {
295        let related_files = example
296            .prompt_inputs
297            .as_ref()
298            .and_then(|pi| pi.related_files.as_ref());
299
300        let Some(related_files) = related_files else {
301            return "(No context)".to_string();
302        };
303
304        if related_files.is_empty() {
305            return "(No context)".to_string();
306        }
307
308        let mut prompt = String::new();
309        for file in related_files {
310            let path_str = file.path.to_string_lossy();
311            writeln!(&mut prompt, "`````{path_str}").ok();
312
313            let mut prev_row = 0;
314            for excerpt in &file.excerpts {
315                if excerpt.row_range.start > prev_row {
316                    prompt.push_str("\n");
317                }
318                prompt.push_str(&excerpt.text);
319                prompt.push('\n');
320                prev_row = excerpt.row_range.end;
321            }
322            if prev_row < file.max_row {
323                prompt.push_str("\n");
324            }
325            prompt.push_str("\n`````\n");
326        }
327
328        prompt
329    }
330
331    fn format_cursor_excerpt(
332        example: &Example,
333        editable_range: Range<usize>,
334        context_range: Range<usize>,
335    ) -> String {
336        let mut result = String::new();
337
338        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
339
340        let path_str = example.spec.cursor_path.to_string_lossy();
341        result.push_str(&format!("`````{path_str}\n"));
342        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
343        result.push_str(Self::EDITABLE_REGION_START);
344        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
345        result.push_str(Self::USER_CURSOR_MARKER);
346        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
347        result.push_str(Self::EDITABLE_REGION_END);
348        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
349        result.push_str("\n`````");
350
351        result
352    }
353
354    pub fn extract_editable_region(text: &str) -> Result<String> {
355        let start = text
356            .rfind(Self::EDITABLE_REGION_START)
357            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
358        let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
359
360        if start >= end {
361            return Err(anyhow!("Invalid editable region markers"));
362        }
363
364        let region = &text[start..end];
365        Ok(region.strip_suffix('\n').unwrap_or(region).to_string())
366    }
367
368    fn is_udiff_content_line(s: &str) -> bool {
369        s.starts_with("-")
370            || s.starts_with("+")
371            || s.starts_with(" ")
372            || s.starts_with("---")
373            || s.starts_with("+++")
374            || s.starts_with("@@")
375    }
376}
377
378/// Extract the cursor excerpt from an example.
379/// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs.
380pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option<String> {
381    // If we have the original prompt, extract the cursor excerpt from it
382    if let Some(prompt) = &example.prompt {
383        // Find "# 3. Current File" section and extract the content
384        if let Some(start) = prompt.input.find("# 3. Current File") {
385            let content_start = prompt.input[start..].find('`').map(|i| start + i)?;
386            let backtick_count = prompt.input[content_start..]
387                .chars()
388                .take_while(|&c| c == '`')
389                .count();
390            let content_start = content_start + backtick_count;
391
392            // Find the path line and skip it
393            let newline_pos = prompt.input[content_start..].find('\n')?;
394            let text_start = content_start + newline_pos + 1;
395
396            // Find the closing backticks
397            let closing_pattern = "`".repeat(backtick_count);
398            let text_end = prompt.input[text_start..].find(&closing_pattern)?;
399            let cursor_excerpt = &prompt.input[text_start..text_start + text_end];
400
401            let path_str = example.spec.cursor_path.to_string_lossy();
402            return Some(format!("`````{path_str}\n{cursor_excerpt}`````"));
403        }
404    }
405
406    // Fallback: construct from prompt_inputs if available
407    let prompt_inputs = example.prompt_inputs.as_ref()?;
408    let content = &prompt_inputs.content;
409    let cursor_offset = prompt_inputs.cursor_offset;
410
411    // Simple fallback: just show content around cursor with markers
412    let path_str = example.spec.cursor_path.to_string_lossy();
413    let mut result = format!("`````{path_str}\n");
414    result.push_str(TeacherPrompt::EDITABLE_REGION_START);
415    result.push_str(&content[..cursor_offset]);
416    result.push_str(TeacherPrompt::USER_CURSOR_MARKER);
417    result.push_str(&content[cursor_offset..]);
418    result.push_str(TeacherPrompt::EDITABLE_REGION_END);
419    result.push_str("\n`````");
420
421    Some(result)
422}
423
/// Returns the content of the last complete fenced code block in `text`.
///
/// A fence is a run of three or more backticks; the rest of the opening line is skipped.
/// The block ends at the first following line that begins with at least as many
/// backticks. The returned content includes its trailing newline. When `text` contains
/// no complete block, the whole of `text` is returned.
pub(crate) fn extract_last_codeblock(text: &str) -> String {
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    loop {
        let Some(fence_start) = text[cursor..].find("```").map(|rel| rel + cursor) else {
            break;
        };

        // Measure the full backtick run so longer fences can wrap shorter ones.
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip whatever follows the backticks on the opening line.
        let after_ticks = fence_start + fence_len;
        let header_end = text[after_ticks..]
            .find('\n')
            .map_or(text.len(), |rel| after_ticks + rel);

        // An unterminated block ends the scan; earlier blocks remain valid.
        let Some(rel_close) = text[header_end..].find(&closing_fence) else {
            break;
        };
        let close_at = header_end + rel_close;

        // Content starts after the opening line's newline and keeps the newline that
        // precedes the closing fence.
        latest = Some(&text[header_end + 1..close_at + 1]);
        cursor = close_at + closing_fence.len();
    }

    latest.map_or_else(|| text.to_string(), str::to_string)
}
455
/// Unit tests for code block extraction and editable region extraction.
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks, the content of the last one is returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Shorter fences inside a five-backtick block do not terminate it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let extracted = extract_last_codeblock(response);
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";
        assert_eq!(extracted, expected);
    }

    /// Backticks in the middle of a line are kept as content.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let extracted = extract_last_codeblock(response);
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";
        assert_eq!(extracted, expected);
    }

    /// Only the text between the region markers is returned, without the trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let prompt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let region = TeacherPrompt::extract_editable_region(prompt).unwrap();
        assert_eq!(region, "one\ntwo three");
    }

    /// A three-backtick block nested in a five-backtick block stays part of the content.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Without markers the whole text is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let unmarked = indoc::indoc! {"
            one
            two three"};
        let region = TeacherPrompt::extract_editable_region(unmarked).unwrap();
        assert_eq!(region, "one\ntwo three");
    }

    /// A `NO_EDITS` code block is recognised once surrounding whitespace is trimmed.
    /// Only the code block extraction step is exercised here, not `TeacherPrompt::parse`.
    #[test]
    fn test_parse_no_edits_response() {
        let response = indoc::indoc! {"
            The code is already complete. There is no clear next edit to make.

            `````
            NO_EDITS
            `````
        "};
        let extracted = extract_last_codeblock(response);
        assert_eq!(extracted.trim(), TeacherPrompt::NO_EDITS);
    }
}