format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result, anyhow};
  9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::ZetaVersion;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    args: &FormatPromptArgs,
 21    app_state: Arc<EpAppState>,
 22    example_progress: &ExampleProgress,
 23    cx: AsyncApp,
 24) -> Result<()> {
 25    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 26
 27    let step_progress = example_progress.start(Step::FormatPrompt);
 28
 29    let prompt_inputs = example
 30        .prompt_inputs
 31        .as_ref()
 32        .context("prompt_inputs must be set after context retrieval")?;
 33
 34    let language = app_state
 35        .languages
 36        .load_language_for_file_path(&example.spec.cursor_path)
 37        .await
 38        .ok();
 39    let snapshot_fut = cx.update(|cx| {
 40        Buffer::build_snapshot(
 41            prompt_inputs.content.as_str().into(),
 42            language,
 43            Some(app_state.languages.clone()),
 44            cx,
 45        )
 46    });
 47    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 48    let snapshot = cx.background_spawn(snapshot_fut).await;
 49
 50    match args.provider {
 51        PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => {
 52            step_progress.set_substatus("formatting teacher prompt");
 53
 54            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 55                cursor_point,
 56                &snapshot,
 57                edit_prediction::zeta2::max_editable_tokens(ZetaVersion::default()),
 58                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 59            );
 60            let editable_range = editable_range.to_offset(&snapshot);
 61            let context_range = context_range.to_offset(&snapshot);
 62
 63            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 64            let expected_output = example
 65                .spec
 66                .expected_patches
 67                .first()
 68                .cloned()
 69                .unwrap_or_default();
 70            let rejected_output = example.spec.rejected_patch.clone();
 71
 72            example.prompt = Some(ExamplePrompt {
 73                input: prompt,
 74                expected_output,
 75                rejected_output,
 76                provider: args.provider,
 77            });
 78        }
 79        PredictionProvider::Zeta2(version) => {
 80            step_progress.set_substatus("formatting zeta2 prompt");
 81
 82            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 83                cursor_point,
 84                &snapshot,
 85                edit_prediction::zeta2::max_editable_tokens(version),
 86                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 87            );
 88            let editable_range = editable_range.to_offset(&snapshot);
 89            let context_range = context_range.to_offset(&snapshot);
 90
 91            let context_start = context_range.start;
 92            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 93            let editable_range_in_excerpt =
 94                (editable_range.start - context_start)..(editable_range.end - context_start);
 95            let input = zeta_prompt::ZetaPromptInput {
 96                cursor_path: example.spec.cursor_path.clone(),
 97                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 98                editable_range_in_excerpt,
 99                cursor_offset_in_excerpt,
100                events: prompt_inputs.edit_history.clone(),
101                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
102            };
103            let prompt = format_zeta_prompt(&input, version);
104            let expected_output = zeta2_output_for_patch(
105                &input,
106                &example
107                    .spec
108                    .expected_patches
109                    .first()
110                    .context("expected patches is empty")?
111                    .clone(),
112                version,
113            )?;
114            let rejected_output = example
115                .spec
116                .rejected_patch
117                .as_ref()
118                .and_then(|patch| zeta2_output_for_patch(&input, &patch, version).ok());
119
120            example.prompt = Some(ExamplePrompt {
121                input: prompt,
122                expected_output,
123                rejected_output,
124                provider: args.provider,
125            });
126        }
127        _ => {
128            panic!("Cannot format prompt for {:?}", args.provider);
129        }
130    };
131    Ok(())
132}
133
134pub fn zeta2_output_for_patch(
135    input: &zeta_prompt::ZetaPromptInput,
136    patch: &str,
137    version: ZetaVersion,
138) -> Result<String> {
139    let mut old_editable_region =
140        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
141
142    if !old_editable_region.ends_with_newline() {
143        old_editable_region.push('\n');
144    }
145
146    let mut result = edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region)
147        .with_context(|| {
148            format!(
149                "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
150                patch, old_editable_region
151            )
152        })?;
153
154    if version == ZetaVersion::V0120GitMergeMarkers {
155        if !result.ends_with('\n') {
156            result.push('\n');
157        }
158        result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
159    }
160
161    Ok(result)
162}
163
164pub struct TeacherPrompt;
165
166impl TeacherPrompt {
167    const PROMPT: &str = include_str!("prompts/teacher.md");
168    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
169    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
170    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
171
172    /// Truncate edit history to this number of last lines
173    const MAX_HISTORY_LINES: usize = 128;
174
175    pub fn format_prompt(
176        example: &Example,
177        editable_range: Range<usize>,
178        context_range: Range<usize>,
179    ) -> String {
180        let edit_history = Self::format_edit_history(&example.spec.edit_history);
181        let context = Self::format_context(example);
182        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
183
184        let prompt = Self::PROMPT
185            .replace("{{context}}", &context)
186            .replace("{{edit_history}}", &edit_history)
187            .replace("{{cursor_excerpt}}", &cursor_excerpt);
188
189        prompt
190    }
191
192    pub fn parse(example: &Example, response: &str) -> Result<String> {
193        // Extract updated (new) editable region from the model response.
194        // The model may include editable region markers in its output, so we need to strip them.
195        let new_editable_region = extract_last_codeblock(response);
196        let mut new_editable_region = Self::extract_editable_region(&new_editable_region)?;
197        let old_editable_region = Self::extract_editable_region(
198            &example
199                .prompt
200                .as_ref()
201                .context("example prompt missing")?
202                .input,
203        )?;
204        let prompt_inputs = example
205            .prompt_inputs
206            .as_ref()
207            .context("example is missing prompt inputs")?;
208
209        // Normalize leading newlines: if old starts with newline but new doesn't,
210        // prepend newline to new to preserve whitespace structure.
211        // This handles the case where the model drops the leading blank line.
212        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
213            new_editable_region.insert(0, '\n');
214        }
215
216        let (editable_region_offset, _) = prompt_inputs
217            .content
218            .match_indices(&old_editable_region)
219            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
220            .context("editable region not found in prompt content")?;
221        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
222            .matches('\n')
223            .count();
224
225        let diff = language::unified_diff_with_offsets(
226            &old_editable_region,
227            &new_editable_region,
228            editable_region_start_line as u32,
229            editable_region_start_line as u32,
230        );
231
232        let diff = indoc::formatdoc! {"
233            --- a/{path}
234            +++ b/{path}
235            {diff}",
236            path = example.spec.cursor_path.to_string_lossy(),
237            diff = diff,
238        };
239
240        Ok(diff)
241    }
242
243    fn format_edit_history(edit_history: &str) -> String {
244        // Strip comments ("garbage lines") from edit history
245        let lines = edit_history
246            .lines()
247            .filter(|&s| Self::is_udiff_content_line(s))
248            .collect::<Vec<_>>();
249
250        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
251            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
252        } else {
253            &lines
254        };
255
256        if history_lines.is_empty() {
257            return "(No edit history)".to_string();
258        }
259
260        history_lines.join("\n")
261    }
262
263    pub fn format_context(example: &Example) -> String {
264        let related_files = example
265            .prompt_inputs
266            .as_ref()
267            .and_then(|pi| pi.related_files.as_ref());
268
269        let Some(related_files) = related_files else {
270            return "(No context)".to_string();
271        };
272
273        if related_files.is_empty() {
274            return "(No context)".to_string();
275        }
276
277        let mut prompt = String::new();
278        for file in related_files {
279            let path_str = file.path.to_string_lossy();
280            writeln!(&mut prompt, "`````{path_str}").ok();
281
282            let mut prev_row = 0;
283            for excerpt in &file.excerpts {
284                if excerpt.row_range.start > prev_row {
285                    prompt.push_str("\n");
286                }
287                prompt.push_str(&excerpt.text);
288                prompt.push('\n');
289                prev_row = excerpt.row_range.end;
290            }
291            if prev_row < file.max_row {
292                prompt.push_str("\n");
293            }
294            prompt.push_str("\n`````\n");
295        }
296
297        prompt
298    }
299
300    fn format_cursor_excerpt(
301        example: &Example,
302        editable_range: Range<usize>,
303        context_range: Range<usize>,
304    ) -> String {
305        let mut result = String::new();
306
307        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
308
309        let path_str = example.spec.cursor_path.to_string_lossy();
310        result.push_str(&format!("`````{path_str}\n"));
311        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
312        result.push_str(Self::EDITABLE_REGION_START);
313        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
314        result.push_str(Self::USER_CURSOR_MARKER);
315        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
316        result.push_str(Self::EDITABLE_REGION_END);
317        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
318        result.push_str("\n`````");
319
320        result
321    }
322
323    fn extract_editable_region(text: &str) -> Result<String> {
324        let start = text
325            .rfind(Self::EDITABLE_REGION_START)
326            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
327        let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
328
329        if start >= end {
330            return Err(anyhow!("Invalid editable region markers"));
331        }
332
333        let region = &text[start..end];
334        let region = region.strip_suffix('\n').unwrap_or(region);
335
336        Ok(region.replace(Self::USER_CURSOR_MARKER, ""))
337    }
338
339    fn is_udiff_content_line(s: &str) -> bool {
340        s.starts_with("-")
341            || s.starts_with("+")
342            || s.starts_with(" ")
343            || s.starts_with("---")
344            || s.starts_with("+++")
345            || s.starts_with("@@")
346    }
347}
348
349/// Extract the cursor excerpt from an example.
350/// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs.
351pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option<String> {
352    // If we have the original prompt, extract the cursor excerpt from it
353    if let Some(prompt) = &example.prompt {
354        // Find "# 3. Current File" section and extract the content
355        if let Some(start) = prompt.input.find("# 3. Current File") {
356            let content_start = prompt.input[start..].find('`').map(|i| start + i)?;
357            let backtick_count = prompt.input[content_start..]
358                .chars()
359                .take_while(|&c| c == '`')
360                .count();
361            let content_start = content_start + backtick_count;
362
363            // Find the path line and skip it
364            let newline_pos = prompt.input[content_start..].find('\n')?;
365            let text_start = content_start + newline_pos + 1;
366
367            // Find the closing backticks
368            let closing_pattern = "`".repeat(backtick_count);
369            let text_end = prompt.input[text_start..].find(&closing_pattern)?;
370            let cursor_excerpt = &prompt.input[text_start..text_start + text_end];
371
372            let path_str = example.spec.cursor_path.to_string_lossy();
373            return Some(format!("`````{path_str}\n{cursor_excerpt}`````"));
374        }
375    }
376
377    // Fallback: construct from prompt_inputs if available
378    let prompt_inputs = example.prompt_inputs.as_ref()?;
379    let content = &prompt_inputs.content;
380    let cursor_offset = prompt_inputs.cursor_offset;
381
382    // Simple fallback: just show content around cursor with markers
383    let path_str = example.spec.cursor_path.to_string_lossy();
384    let mut result = format!("`````{path_str}\n");
385    result.push_str(TeacherPrompt::EDITABLE_REGION_START);
386    result.push_str(&content[..cursor_offset]);
387    result.push_str(TeacherPrompt::USER_CURSOR_MARKER);
388    result.push_str(&content[cursor_offset..]);
389    result.push_str(TeacherPrompt::EDITABLE_REGION_END);
390    result.push_str("\n`````");
391
392    Some(result)
393}
394
/// Returns the body of the last complete fenced code block in `text`.
///
/// Matching works as follows:
///
/// - A fence opens at any run of three or more backticks, and the rest of that line
///   (e.g. a language tag or path) is skipped.
/// - The block closes at the next line that starts with at least as many backticks as
///   the opening run, so shorter nested fences stay inside the body.
/// - The body keeps the newline that precedes the closing fence.
/// - Scanning resumes after each closing fence and stops at the first fence that is
///   never closed.
///
/// If no block was closed, `text` is returned unchanged.
fn extract_last_codeblock(text: &str) -> String {
    let mut last_body: Option<&str> = None;
    let mut scan_from = 0;

    while let Some(found_at) = text[scan_from..].find("```") {
        let fence_start = scan_from + found_at;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let after_fence = fence_start + fence_len;
        // Index of the newline ending the opening fence line, or the end of `text`.
        let line_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |offset| after_fence + offset);

        let closing_fence = format!("\n{}", "`".repeat(fence_len));
        let Some(body_len) = text[line_end..].find(&closing_fence) else {
            break;
        };
        last_body = Some(&text[line_end + 1..line_end + body_len + 1]);
        scan_from = line_end + body_len + closing_fence.len();
    }

    last_body.unwrap_or(text).to_string()
}
426
// Unit tests for the fenced-code-block extraction and editable-region extraction helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // With several fenced blocks present, only the body of the last one is returned,
    // even when its fence is longer and carries an info string.
    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    // Shorter backtick runs inside a five-backtick block do not close it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    // Inline and mid-line backticks do not terminate the block; a closing fence must
    // start a line.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    // Only the text between the region markers is returned; lines outside the markers
    // are dropped, as is the trailing newline before the end marker.
    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text).unwrap();
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // A ```bibtex block nested inside a five-backtick block is kept verbatim as part of
    // the outer block's body.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#}
        );
    }

    // Without any markers, the whole input is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text).unwrap();
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // The user cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let text = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let parsed = TeacherPrompt::extract_editable_region(text).unwrap();
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }
}