format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result, anyhow};
  9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::ZetaVersion;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    args: &FormatPromptArgs,
 21    app_state: Arc<EpAppState>,
 22    example_progress: &ExampleProgress,
 23    cx: AsyncApp,
 24) -> Result<()> {
 25    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 26
 27    let step_progress = example_progress.start(Step::FormatPrompt);
 28
 29    let prompt_inputs = example
 30        .prompt_inputs
 31        .as_ref()
 32        .context("prompt_inputs must be set after context retrieval")?;
 33
 34    let language = app_state
 35        .languages
 36        .load_language_for_file_path(&example.spec.cursor_path)
 37        .await
 38        .ok();
 39    let snapshot_fut = cx.update(|cx| {
 40        Buffer::build_snapshot(
 41            prompt_inputs.content.as_str().into(),
 42            language,
 43            Some(app_state.languages.clone()),
 44            cx,
 45        )
 46    });
 47    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 48    let snapshot = cx.background_spawn(snapshot_fut).await;
 49
 50    match args.provider {
 51        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 52            step_progress.set_substatus("formatting teacher prompt");
 53
 54            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 55                cursor_point,
 56                &snapshot,
 57                edit_prediction::zeta2::max_editable_tokens(ZetaVersion::default()),
 58                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 59            );
 60            let editable_range = editable_range.to_offset(&snapshot);
 61            let context_range = context_range.to_offset(&snapshot);
 62
 63            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 64            let expected_output = example
 65                .spec
 66                .expected_patches
 67                .first()
 68                .cloned()
 69                .unwrap_or_default();
 70            let rejected_output = example.spec.rejected_patch.clone();
 71
 72            example.prompt = Some(ExamplePrompt {
 73                input: prompt,
 74                expected_output,
 75                rejected_output,
 76                provider: args.provider,
 77            });
 78        }
 79        PredictionProvider::Zeta2(version) => {
 80            step_progress.set_substatus("formatting zeta2 prompt");
 81
 82            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 83                cursor_point,
 84                &snapshot,
 85                edit_prediction::zeta2::max_editable_tokens(version),
 86                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 87            );
 88            let editable_range = editable_range.to_offset(&snapshot);
 89            let context_range = context_range.to_offset(&snapshot);
 90
 91            let context_start = context_range.start;
 92            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 93            let editable_range_in_excerpt =
 94                (editable_range.start - context_start)..(editable_range.end - context_start);
 95            let input = zeta_prompt::ZetaPromptInput {
 96                cursor_path: example.spec.cursor_path.clone(),
 97                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 98                editable_range_in_excerpt,
 99                cursor_offset_in_excerpt,
100                excerpt_start_row: prompt_inputs.excerpt_start_row,
101                events: prompt_inputs.edit_history.clone(),
102                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
103            };
104            let prompt = format_zeta_prompt(&input, version);
105            let expected_output = zeta2_output_for_patch(
106                &input,
107                &example
108                    .spec
109                    .expected_patches
110                    .first()
111                    .context("expected patches is empty")?
112                    .clone(),
113                version,
114            )?;
115            let rejected_output = example
116                .spec
117                .rejected_patch
118                .as_ref()
119                .and_then(|patch| zeta2_output_for_patch(&input, &patch, version).ok());
120
121            example.prompt = Some(ExamplePrompt {
122                input: prompt,
123                expected_output,
124                rejected_output,
125                provider: args.provider,
126            });
127        }
128        _ => {
129            panic!("Cannot format prompt for {:?}", args.provider);
130        }
131    };
132    Ok(())
133}
134
135pub fn zeta2_output_for_patch(
136    input: &zeta_prompt::ZetaPromptInput,
137    patch: &str,
138    version: ZetaVersion,
139) -> Result<String> {
140    let mut old_editable_region =
141        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
142
143    if !old_editable_region.ends_with_newline() {
144        old_editable_region.push('\n');
145    }
146
147    let mut result = edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region)
148        .with_context(|| {
149            format!(
150                "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
151                patch, old_editable_region
152            )
153        })?;
154
155    if version == ZetaVersion::V0120GitMergeMarkers {
156        if !result.ends_with('\n') {
157            result.push('\n');
158        }
159        result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
160    }
161
162    Ok(result)
163}
164
165pub struct TeacherPrompt;
166
167impl TeacherPrompt {
168    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
169    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
170    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
171    pub(crate) const NO_EDITS: &str = "NO_EDITS";
172
173    /// Truncate edit history to this number of last lines
174    const MAX_HISTORY_LINES: usize = 128;
175
176    pub fn format_prompt(
177        example: &Example,
178        editable_range: Range<usize>,
179        context_range: Range<usize>,
180    ) -> String {
181        let edit_history = Self::format_edit_history(&example.spec.edit_history);
182        let context = Self::format_context(example);
183        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
184
185        let prompt_template = crate::prompt_assets::get_prompt("teacher.md");
186        let prompt = prompt_template
187            .replace("{{context}}", &context)
188            .replace("{{edit_history}}", &edit_history)
189            .replace("{{cursor_excerpt}}", &cursor_excerpt);
190
191        prompt
192    }
193
194    pub fn parse(example: &Example, response: &str) -> Result<String> {
195        // Extract updated (new) editable region from the model response.
196        // The model may include editable region markers in its output, so we need to strip them.
197        let new_editable_region = extract_last_codeblock(response);
198
199        // Check if the model indicated no edits are needed
200        if new_editable_region.trim() == Self::NO_EDITS {
201            return Ok(String::new());
202        }
203
204        let mut new_editable_region = Self::extract_editable_region(&new_editable_region)?;
205        let old_editable_region = Self::extract_editable_region(
206            &example
207                .prompt
208                .as_ref()
209                .context("example prompt missing")?
210                .input,
211        )?;
212        let prompt_inputs = example
213            .prompt_inputs
214            .as_ref()
215            .context("example is missing prompt inputs")?;
216
217        // Normalize leading newlines: if old starts with newline but new doesn't,
218        // prepend newline to new to preserve whitespace structure.
219        // This handles the case where the model drops the leading blank line.
220        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
221            new_editable_region.insert(0, '\n');
222        }
223
224        let (editable_region_offset, _) = prompt_inputs
225            .content
226            .match_indices(&old_editable_region)
227            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
228            .context("editable region not found in prompt content")?;
229        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
230            .matches('\n')
231            .count();
232
233        let diff = language::unified_diff_with_offsets(
234            &old_editable_region,
235            &new_editable_region,
236            editable_region_start_line as u32,
237            editable_region_start_line as u32,
238        );
239
240        let diff = indoc::formatdoc! {"
241            --- a/{path}
242            +++ b/{path}
243            {diff}",
244            path = example.spec.cursor_path.to_string_lossy(),
245            diff = diff,
246        };
247
248        Ok(diff)
249    }
250
251    fn format_edit_history(edit_history: &str) -> String {
252        // Strip comments ("garbage lines") from edit history
253        let lines = edit_history
254            .lines()
255            .filter(|&s| Self::is_udiff_content_line(s))
256            .collect::<Vec<_>>();
257
258        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
259            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
260        } else {
261            &lines
262        };
263
264        if history_lines.is_empty() {
265            return "(No edit history)".to_string();
266        }
267
268        history_lines.join("\n")
269    }
270
271    pub fn format_context(example: &Example) -> String {
272        let related_files = example
273            .prompt_inputs
274            .as_ref()
275            .and_then(|pi| pi.related_files.as_ref());
276
277        let Some(related_files) = related_files else {
278            return "(No context)".to_string();
279        };
280
281        if related_files.is_empty() {
282            return "(No context)".to_string();
283        }
284
285        let mut prompt = String::new();
286        for file in related_files {
287            let path_str = file.path.to_string_lossy();
288            writeln!(&mut prompt, "`````{path_str}").ok();
289
290            let mut prev_row = 0;
291            for excerpt in &file.excerpts {
292                if excerpt.row_range.start > prev_row {
293                    prompt.push_str("\n");
294                }
295                prompt.push_str(&excerpt.text);
296                prompt.push('\n');
297                prev_row = excerpt.row_range.end;
298            }
299            if prev_row < file.max_row {
300                prompt.push_str("\n");
301            }
302            prompt.push_str("\n`````\n");
303        }
304
305        prompt
306    }
307
308    fn format_cursor_excerpt(
309        example: &Example,
310        editable_range: Range<usize>,
311        context_range: Range<usize>,
312    ) -> String {
313        let mut result = String::new();
314
315        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
316
317        let path_str = example.spec.cursor_path.to_string_lossy();
318        result.push_str(&format!("`````{path_str}\n"));
319        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
320        result.push_str(Self::EDITABLE_REGION_START);
321        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
322        result.push_str(Self::USER_CURSOR_MARKER);
323        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
324        result.push_str(Self::EDITABLE_REGION_END);
325        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
326        result.push_str("\n`````");
327
328        result
329    }
330
331    fn extract_editable_region(text: &str) -> Result<String> {
332        let start = text
333            .rfind(Self::EDITABLE_REGION_START)
334            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
335        let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
336
337        if start >= end {
338            return Err(anyhow!("Invalid editable region markers"));
339        }
340
341        let region = &text[start..end];
342        let region = region.strip_suffix('\n').unwrap_or(region);
343
344        Ok(region.replace(Self::USER_CURSOR_MARKER, ""))
345    }
346
347    fn is_udiff_content_line(s: &str) -> bool {
348        s.starts_with("-")
349            || s.starts_with("+")
350            || s.starts_with(" ")
351            || s.starts_with("---")
352            || s.starts_with("+++")
353            || s.starts_with("@@")
354    }
355}
356
357/// Extract the cursor excerpt from an example.
358/// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs.
359pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option<String> {
360    // If we have the original prompt, extract the cursor excerpt from it
361    if let Some(prompt) = &example.prompt {
362        // Find "# 3. Current File" section and extract the content
363        if let Some(start) = prompt.input.find("# 3. Current File") {
364            let content_start = prompt.input[start..].find('`').map(|i| start + i)?;
365            let backtick_count = prompt.input[content_start..]
366                .chars()
367                .take_while(|&c| c == '`')
368                .count();
369            let content_start = content_start + backtick_count;
370
371            // Find the path line and skip it
372            let newline_pos = prompt.input[content_start..].find('\n')?;
373            let text_start = content_start + newline_pos + 1;
374
375            // Find the closing backticks
376            let closing_pattern = "`".repeat(backtick_count);
377            let text_end = prompt.input[text_start..].find(&closing_pattern)?;
378            let cursor_excerpt = &prompt.input[text_start..text_start + text_end];
379
380            let path_str = example.spec.cursor_path.to_string_lossy();
381            return Some(format!("`````{path_str}\n{cursor_excerpt}`````"));
382        }
383    }
384
385    // Fallback: construct from prompt_inputs if available
386    let prompt_inputs = example.prompt_inputs.as_ref()?;
387    let content = &prompt_inputs.content;
388    let cursor_offset = prompt_inputs.cursor_offset;
389
390    // Simple fallback: just show content around cursor with markers
391    let path_str = example.spec.cursor_path.to_string_lossy();
392    let mut result = format!("`````{path_str}\n");
393    result.push_str(TeacherPrompt::EDITABLE_REGION_START);
394    result.push_str(&content[..cursor_offset]);
395    result.push_str(TeacherPrompt::USER_CURSOR_MARKER);
396    result.push_str(&content[cursor_offset..]);
397    result.push_str(TeacherPrompt::EDITABLE_REGION_END);
398    result.push_str("\n`````");
399
400    Some(result)
401}
402
/// Returns the content of the last complete fenced code block in `text`.
///
/// A fence is a run of three or more backticks; the rest of the opening line is skipped. The
/// block ends at the first following line that starts with the same number of backticks. The
/// returned content includes the newline that precedes the closing fence. Scanning stops at
/// the first fence that is never closed. When no complete block exists, the whole `text` is
/// returned.
fn extract_last_codeblock(text: &str) -> String {
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    loop {
        let Some(relative_start) = text[cursor..].find("```") else {
            break;
        };
        let fence_start = cursor + relative_start;

        // The opening fence may be longer than three backticks; the closing one must match.
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip whatever follows the fence on the opening line.
        let after_fence = fence_start + fence_len;
        let header_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |offset| after_fence + offset);

        let Some(relative_close) = text[header_end..].find(&closing_fence) else {
            break;
        };
        let close_at = header_end + relative_close;

        // Content starts after the header's newline and keeps the newline before the fence.
        latest = Some(&text[header_end + 1..close_at + 1]);
        cursor = close_at + closing_fence.len();
    }

    match latest {
        Some(block) => block.to_string(),
        None => text.to_string(),
    }
}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in the response, only the final one is returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Shorter fences inside a five-backtick block do not terminate it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Backticks in the middle of a content line are left alone.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Only the text between the markers survives, without its trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let excerpt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected = indoc::indoc! {"
            one
            two three"};
        let region = TeacherPrompt::extract_editable_region(excerpt).unwrap();
        assert_eq!(region, expected);
    }

    /// A three-backtick block nested in the answer is kept as part of the content.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Without markers the entire input is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let excerpt = indoc::indoc! {"
            one
            two three"};
        let expected = indoc::indoc! {"
            one
            two three"};
        let region = TeacherPrompt::extract_editable_region(excerpt).unwrap();
        assert_eq!(region, expected);
    }

    /// A `NO_EDITS` answer is recognisable once the extracted block is trimmed.
    #[test]
    fn test_parse_no_edits_response() {
        let response = indoc::indoc! {"
            The code is already complete. There is no clear next edit to make.

            `````
            NO_EDITS
            `````
        "};
        let block = extract_last_codeblock(response);
        assert_eq!(block.trim(), TeacherPrompt::NO_EDITS);
    }

    /// The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let excerpt = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let expected = indoc::indoc! {"
            one
            two three"};
        let region = TeacherPrompt::extract_editable_region(excerpt).unwrap();
        assert_eq!(region, expected);
    }
}