format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result};
  9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::ZetaVersion;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    args: &FormatPromptArgs,
 21    app_state: Arc<EpAppState>,
 22    example_progress: &ExampleProgress,
 23    cx: AsyncApp,
 24) -> Result<()> {
 25    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 26
 27    let step_progress = example_progress.start(Step::FormatPrompt);
 28
 29    let prompt_inputs = example
 30        .prompt_inputs
 31        .as_ref()
 32        .context("prompt_inputs must be set after context retrieval")?;
 33
 34    let language = app_state
 35        .languages
 36        .load_language_for_file_path(&example.spec.cursor_path)
 37        .await
 38        .ok();
 39    let snapshot_fut = cx.update(|cx| {
 40        Buffer::build_snapshot(
 41            prompt_inputs.content.as_str().into(),
 42            language,
 43            Some(app_state.languages.clone()),
 44            cx,
 45        )
 46    });
 47    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 48    let snapshot = cx.background_spawn(snapshot_fut).await;
 49
 50    match args.provider {
 51        PredictionProvider::Teacher(_)
 52        | PredictionProvider::TeacherNonBatching(_)
 53        | PredictionProvider::RepairedTeacher(_) => {
 54            step_progress.set_substatus("formatting teacher prompt");
 55
 56            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 57                cursor_point,
 58                &snapshot,
 59                edit_prediction::zeta2::max_editable_tokens(ZetaVersion::default()),
 60                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 61            );
 62            let editable_range = editable_range.to_offset(&snapshot);
 63            let context_range = context_range.to_offset(&snapshot);
 64
 65            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 66            let expected_output = example
 67                .spec
 68                .expected_patches
 69                .first()
 70                .cloned()
 71                .unwrap_or_default();
 72            let rejected_output = example.spec.rejected_patch.clone();
 73
 74            example.prompt = Some(ExamplePrompt {
 75                input: prompt,
 76                expected_output,
 77                rejected_output,
 78                provider: args.provider,
 79            });
 80        }
 81        PredictionProvider::Zeta2(version) => {
 82            step_progress.set_substatus("formatting zeta2 prompt");
 83
 84            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 85                cursor_point,
 86                &snapshot,
 87                edit_prediction::zeta2::max_editable_tokens(version),
 88                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 89            );
 90            let editable_range = editable_range.to_offset(&snapshot);
 91            let context_range = context_range.to_offset(&snapshot);
 92
 93            let context_start = context_range.start;
 94            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 95            let editable_range_in_excerpt =
 96                (editable_range.start - context_start)..(editable_range.end - context_start);
 97            let input = zeta_prompt::ZetaPromptInput {
 98                cursor_path: example.spec.cursor_path.clone(),
 99                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
100                editable_range_in_excerpt,
101                cursor_offset_in_excerpt,
102                events: prompt_inputs.edit_history.clone(),
103                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
104            };
105            let prompt = format_zeta_prompt(&input, version);
106            let expected_output = zeta2_output_for_patch(
107                &input,
108                &example
109                    .spec
110                    .expected_patches
111                    .first()
112                    .context("expected patches is empty")?
113                    .clone(),
114                version,
115            )?;
116            let rejected_output = example
117                .spec
118                .rejected_patch
119                .as_ref()
120                .and_then(|patch| zeta2_output_for_patch(&input, &patch, version).ok());
121
122            example.prompt = Some(ExamplePrompt {
123                input: prompt,
124                expected_output,
125                rejected_output,
126                provider: args.provider,
127            });
128        }
129        _ => {
130            panic!("Cannot format prompt for {:?}", args.provider);
131        }
132    };
133    Ok(())
134}
135
136pub fn zeta2_output_for_patch(
137    input: &zeta_prompt::ZetaPromptInput,
138    patch: &str,
139    version: ZetaVersion,
140) -> Result<String> {
141    let mut old_editable_region =
142        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
143
144    if !old_editable_region.ends_with_newline() {
145        old_editable_region.push('\n');
146    }
147
148    let mut result = edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region)
149        .with_context(|| {
150            format!(
151                "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
152                patch, old_editable_region
153            )
154        })?;
155
156    if version == ZetaVersion::V0120GitMergeMarkers {
157        if !result.ends_with('\n') {
158            result.push('\n');
159        }
160        result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
161    }
162
163    Ok(result)
164}
165
/// Field-less type that groups the associated functions and constants used to build the
/// teacher prompt and to parse the teacher's response.
pub struct TeacherPrompt;
167
168impl TeacherPrompt {
    /// Prompt template, embedded at compile time from `prompts/teacher.md`.
    const PROMPT: &str = include_str!("prompts/teacher.md");
    /// Marker written right before the editable text; note the trailing newline.
    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
    /// Marker written right after the editable text; note the leading newline.
    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
    /// Marker written at the cursor offset inside the cursor excerpt.
    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";

    /// Truncate edit history to this number of last lines
    const MAX_HISTORY_LINES: usize = 128;
176
177    pub fn format_prompt(
178        example: &Example,
179        editable_range: Range<usize>,
180        context_range: Range<usize>,
181    ) -> String {
182        let edit_history = Self::format_edit_history(&example.spec.edit_history);
183        let context = Self::format_context(example);
184        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
185
186        let prompt = Self::PROMPT
187            .replace("{{context}}", &context)
188            .replace("{{edit_history}}", &edit_history)
189            .replace("{{cursor_excerpt}}", &cursor_excerpt);
190
191        prompt
192    }
193
194    pub fn parse(example: &Example, response: &str) -> Result<String> {
195        // Extract updated (new) editable region from the model response.
196        // The model may include editable region markers in its output, so we need to strip them.
197        let new_editable_region = extract_last_codeblock(response);
198        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
199        let old_editable_region = Self::extract_editable_region(
200            &example
201                .prompt
202                .as_ref()
203                .context("example prompt missing")?
204                .input,
205        );
206        let prompt_inputs = example
207            .prompt_inputs
208            .as_ref()
209            .context("example is missing prompt inputs")?;
210
211        // Normalize leading newlines: if old starts with newline but new doesn't,
212        // prepend newline to new to preserve whitespace structure.
213        // This handles the case where the model drops the leading blank line.
214        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
215            new_editable_region.insert(0, '\n');
216        }
217
218        let (editable_region_offset, _) = prompt_inputs
219            .content
220            .match_indices(&old_editable_region)
221            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
222            .unwrap();
223        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
224            .matches('\n')
225            .count();
226
227        let diff = language::unified_diff_with_offsets(
228            &old_editable_region,
229            &new_editable_region,
230            editable_region_start_line as u32,
231            editable_region_start_line as u32,
232        );
233
234        let diff = indoc::formatdoc! {"
235            --- a/{path}
236            +++ b/{path}
237            {diff}",
238            path = example.spec.cursor_path.to_string_lossy(),
239            diff = diff,
240        };
241
242        Ok(diff)
243    }
244
245    fn format_edit_history(edit_history: &str) -> String {
246        // Strip comments ("garbage lines") from edit history
247        let lines = edit_history
248            .lines()
249            .filter(|&s| Self::is_udiff_content_line(s))
250            .collect::<Vec<_>>();
251
252        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
253            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
254        } else {
255            &lines
256        };
257
258        if history_lines.is_empty() {
259            return "(No edit history)".to_string();
260        }
261
262        history_lines.join("\n")
263    }
264
265    pub fn format_context(example: &Example) -> String {
266        let related_files = example
267            .prompt_inputs
268            .as_ref()
269            .and_then(|pi| pi.related_files.as_ref());
270
271        let Some(related_files) = related_files else {
272            return "(No context)".to_string();
273        };
274
275        if related_files.is_empty() {
276            return "(No context)".to_string();
277        }
278
279        let mut prompt = String::new();
280        for file in related_files {
281            let path_str = file.path.to_string_lossy();
282            writeln!(&mut prompt, "`````{path_str}").ok();
283
284            let mut prev_row = 0;
285            for excerpt in &file.excerpts {
286                if excerpt.row_range.start > prev_row {
287                    prompt.push_str("\n");
288                }
289                prompt.push_str(&excerpt.text);
290                prompt.push('\n');
291                prev_row = excerpt.row_range.end;
292            }
293            if prev_row < file.max_row {
294                prompt.push_str("\n");
295            }
296            prompt.push_str("\n`````\n");
297        }
298
299        prompt
300    }
301
302    fn format_cursor_excerpt(
303        example: &Example,
304        editable_range: Range<usize>,
305        context_range: Range<usize>,
306    ) -> String {
307        let mut result = String::new();
308
309        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
310
311        let path_str = example.spec.cursor_path.to_string_lossy();
312        result.push_str(&format!("`````{path_str}\n"));
313        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
314        result.push_str(Self::EDITABLE_REGION_START);
315        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
316        result.push_str(Self::USER_CURSOR_MARKER);
317        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
318        result.push_str(Self::EDITABLE_REGION_END);
319        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
320        result.push_str("\n`````");
321
322        result
323    }
324
325    fn extract_editable_region(text: &str) -> String {
326        let start = text
327            .rfind(Self::EDITABLE_REGION_START)
328            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
329        let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
330
331        let region = &text[start..end];
332        let region = region.strip_suffix('\n').unwrap_or(region);
333
334        region.replace(Self::USER_CURSOR_MARKER, "")
335    }
336
337    fn is_udiff_content_line(s: &str) -> bool {
338        s.starts_with("-")
339            || s.starts_with("+")
340            || s.starts_with(" ")
341            || s.starts_with("---")
342            || s.starts_with("+++")
343            || s.starts_with("@@")
344    }
345}
346
347/// Extract the cursor excerpt from an example.
348/// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs.
349pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option<String> {
350    // If we have the original prompt, extract the cursor excerpt from it
351    if let Some(prompt) = &example.prompt {
352        // Find "# 3. Current File" section and extract the content
353        if let Some(start) = prompt.input.find("# 3. Current File") {
354            let content_start = prompt.input[start..].find('`').map(|i| start + i)?;
355            let backtick_count = prompt.input[content_start..]
356                .chars()
357                .take_while(|&c| c == '`')
358                .count();
359            let content_start = content_start + backtick_count;
360
361            // Find the path line and skip it
362            let newline_pos = prompt.input[content_start..].find('\n')?;
363            let text_start = content_start + newline_pos + 1;
364
365            // Find the closing backticks
366            let closing_pattern = "`".repeat(backtick_count);
367            let text_end = prompt.input[text_start..].find(&closing_pattern)?;
368            let cursor_excerpt = &prompt.input[text_start..text_start + text_end];
369
370            let path_str = example.spec.cursor_path.to_string_lossy();
371            return Some(format!("`````{path_str}\n{cursor_excerpt}`````"));
372        }
373    }
374
375    // Fallback: construct from prompt_inputs if available
376    let prompt_inputs = example.prompt_inputs.as_ref()?;
377    let content = &prompt_inputs.content;
378    let cursor_offset = prompt_inputs.cursor_offset;
379
380    // Simple fallback: just show content around cursor with markers
381    let path_str = example.spec.cursor_path.to_string_lossy();
382    let mut result = format!("`````{path_str}\n");
383    result.push_str(TeacherPrompt::EDITABLE_REGION_START);
384    result.push_str(&content[..cursor_offset]);
385    result.push_str(TeacherPrompt::USER_CURSOR_MARKER);
386    result.push_str(&content[cursor_offset..]);
387    result.push_str(TeacherPrompt::EDITABLE_REGION_END);
388    result.push_str("\n`````");
389
390    Some(result)
391}
392
/// Returns the body of the last complete fenced code block in `text`.
///
/// A fence opens at any run of three or more backticks; the rest of that line is skipped.
/// The block is closed by the next line that starts with a run of backticks of the same
/// length as the opening run, which lets shorter fences nest inside longer ones. The returned
/// body includes the newline before the closing fence. When `text` holds no complete block
/// the whole input is returned.
fn extract_last_codeblock(text: &str) -> String {
    let mut latest_block: Option<&str> = None;
    let mut cursor = 0;

    while let Some(relative_start) = text[cursor..].find("```") {
        let fence_start = cursor + relative_start;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip whatever follows the backticks on the opening line (e.g. a language tag).
        let after_fence = fence_start + fence_len;
        let header_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |offset| after_fence + offset);

        let Some(body_len) = text[header_end..].find(&closing_fence) else {
            // An unterminated fence ends the scan; earlier blocks stay valid.
            break;
        };

        latest_block = Some(&text[header_end + 1..header_end + body_len + 1]);
        cursor = header_end + body_len + closing_fence.len();
    }

    latest_block.map_or_else(|| text.to_string(), str::to_string)
}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in a response, only the final one is returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );
        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Three-backtick runs inside a five-backtick block do not terminate it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = concat!(
            "`````\n",
            "content with ``` inline\n",
            "and ```python nested\n",
            "more content\n",
            "`````\n",
        );
        assert_eq!(
            extract_last_codeblock(response),
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    /// Inline backticks in the middle of a line are kept as part of the block body.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = concat!(
            "`````\n",
            "here is some `code` with inline backticks\n",
            "and here```more```stuff\n",
            "`````\n",
        );
        assert_eq!(
            extract_last_codeblock(response),
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    /// Only the text between the markers is returned, minus one trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let prompt = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );
        assert_eq!(
            TeacherPrompt::extract_editable_region(prompt),
            "one\ntwo three"
        );
    }

    /// A complete three-backtick block nested in a five-backtick block stays intact.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = concat!(
            "Looking at the edit history, I can see that a Citation section was just added.\n",
            "\n",
            "`````\n",
            "## Collaborations\n",
            "Our mission is to create a 4D generative model.\n",
            "\n",
            "## Citation\n",
            "\n",
            "If you found Unique3D helpful, please cite our report:\n",
            "```bibtex\n",
            "@misc{wu2024unique3d,\n",
            "      title={Unique3D},\n",
            "}\n",
            "```\n",
            "`````\n",
        );
        let expected_block = concat!(
            "## Collaborations\n",
            "Our mission is to create a 4D generative model.\n",
            "\n",
            "## Citation\n",
            "\n",
            "If you found Unique3D helpful, please cite our report:\n",
            "```bibtex\n",
            "@misc{wu2024unique3d,\n",
            "      title={Unique3D},\n",
            "}\n",
            "```\n",
        );
        assert_eq!(extract_last_codeblock(response), expected_block);
    }

    /// Text without any markers is returned unchanged.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let unmarked = "one\ntwo three";
        assert_eq!(
            TeacherPrompt::extract_editable_region(unmarked),
            "one\ntwo three"
        );
    }

    /// The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let prompt = concat!(
            "<|editable_region_start|>\n",
            "one\n",
            "<|user_cursor|>two three\n",
            "\n",
            "<|editable_region_end|>\n",
        );
        assert_eq!(
            TeacherPrompt::extract_editable_region(prompt),
            "one\ntwo three"
        );
    }
}