ep: Add option to use prompt prefill (#48964)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/zeta2.rs             | 10 ++
crates/edit_prediction_cli/src/example.rs       |  2 
crates/edit_prediction_cli/src/format_prompt.rs |  3 +
crates/edit_prediction_cli/src/parse_output.rs  | 10 ++-
crates/edit_prediction_cli/src/predict.rs       |  1 
crates/zeta_prompt/src/zeta_prompt.rs           | 54 ++++++++++++++++++
6 files changed, 74 insertions(+), 6 deletions(-)

Detailed changes

crates/edit_prediction/src/zeta2.rs 🔗

@@ -13,7 +13,8 @@ use release_channel::AppVersion;
 
 use std::env;
 use std::{path::Path, sync::Arc, time::Instant};
-use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt};
+use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output};
+use zeta_prompt::{format_zeta_prompt, get_prefill};
 
 pub const MAX_CONTEXT_TOKENS: usize = 350;
 
@@ -23,6 +24,7 @@ pub fn max_editable_tokens(format: ZetaFormat) -> usize {
         ZetaFormat::V0114180EditableRegion => 180,
         ZetaFormat::V0120GitMergeMarkers => 180,
         ZetaFormat::V0131GitMergeMarkersPrefix => 180,
+        ZetaFormat::V0211Prefill => 180,
     }
 }
 
@@ -88,6 +90,8 @@ pub fn request_prediction_with_zeta2(
 
             let (request_id, output_text, usage) = if let Some(config) = &raw_config {
                 let prompt = format_zeta_prompt(&prompt_input, config.format);
+                let prefill = get_prefill(&prompt_input, config.format);
+                let prompt = format!("{prompt}{prefill}");
                 let request = RawCompletionRequest {
                     model: config.model_id.clone().unwrap_or_default(),
                     prompt,
@@ -108,7 +112,9 @@ pub fn request_prediction_with_zeta2(
 
                 let request_id = EditPredictionId(response.id.clone().into());
                 let output_text = response.choices.pop().map(|choice| {
-                    clean_zeta2_model_output(&choice.text, config.format).to_string()
+                    let response = &choice.text;
+                    let output = format!("{prefill}{response}");
+                    clean_zeta2_model_output(&output, config.format).to_string()
                 });
 
                 (request_id, output_text, usage)

crates/edit_prediction_cli/src/example.rs 🔗

@@ -76,6 +76,8 @@ pub struct ExamplePrompt {
     pub input: String,
     pub expected_output: String,
     pub rejected_output: Option<String>, // For DPO
+    #[serde(default)]
+    pub prefill: Option<String>,
     pub provider: PredictionProvider,
 }
 

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -65,6 +65,7 @@ pub async fn run_format_prompt(
                 input: prompt,
                 expected_output: String::new(),
                 rejected_output: None,
+                prefill: None,
                 provider: args.provider,
             });
         }
@@ -94,6 +95,7 @@ pub async fn run_format_prompt(
                 related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
             };
             let prompt = format_zeta_prompt(&input, version);
+            let prefill = zeta_prompt::get_prefill(&input, version);
             let (expected_patch, expected_cursor_offset) = example
                 .spec
                 .expected_patches_with_cursor_positions()
@@ -113,6 +115,7 @@ pub async fn run_format_prompt(
                 expected_output,
                 rejected_output,
                 provider: args.provider,
+                prefill: Some(prefill),
             });
         }
         _ => {

crates/edit_prediction_cli/src/parse_output.rs 🔗

@@ -55,7 +55,9 @@ fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<Stri
         ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
             ("<|fim_middle|>current\n", "<|fim_suffix|>")
         }
-        ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => (
+        ZetaFormat::V0120GitMergeMarkers
+        | ZetaFormat::V0131GitMergeMarkersPrefix
+        | ZetaFormat::V0211Prefill => (
             zeta_prompt::v0120_git_merge_markers::START_MARKER,
             zeta_prompt::v0120_git_merge_markers::SEPARATOR,
         ),
@@ -101,11 +103,13 @@ fn parse_zeta2_output(
     };
 
     let suffix = match format {
-        ZetaFormat::V0131GitMergeMarkersPrefix => {
+        ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
             zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
         }
         ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
-        _ => "",
+        ZetaFormat::V0112MiddleAtEnd
+        | ZetaFormat::V0113Ordered
+        | ZetaFormat::V0114180EditableRegion => "",
     };
     if !suffix.is_empty() {
         new_text = new_text

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -9,6 +9,11 @@ use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
 pub const CURSOR_MARKER: &str = "<|user_cursor|>";
 pub const MAX_PROMPT_TOKENS: usize = 4096;
 
+/// Fraction of the editable region (measured in bytes) considered for prefill,
+/// before the cut is snapped to a newline or space boundary. Larger values may
+/// result in more robust generation, but the prefilled text becomes non-editable.
+pub const PREFILL_RATIO: f64 = 0.1; // 10%
+
 fn estimate_tokens(bytes: usize) -> usize {
     bytes / 3
 }
@@ -46,6 +51,7 @@ pub enum ZetaFormat {
     V0114180EditableRegion,
     V0120GitMergeMarkers,
     V0131GitMergeMarkersPrefix,
+    V0211Prefill,
 }
 
 impl std::fmt::Display for ZetaFormat {
@@ -170,7 +176,7 @@ fn format_zeta_prompt_with_budget(
         ZetaFormat::V0120GitMergeMarkers => {
             v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
         }
-        ZetaFormat::V0131GitMergeMarkersPrefix => {
+        ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
             v0131_git_merge_markers_prefix::write_cursor_excerpt_section(&mut cursor_section, input)
         }
     }
@@ -193,6 +199,17 @@ fn format_zeta_prompt_with_budget(
     prompt
 }
 
+pub fn get_prefill(input: &ZetaPromptInput, format: ZetaFormat) -> String {
+    match format { // maps a format to its prefill text; no wildcard arm, all variants listed
+        ZetaFormat::V0112MiddleAtEnd
+        | ZetaFormat::V0113Ordered
+        | ZetaFormat::V0114180EditableRegion
+        | ZetaFormat::V0120GitMergeMarkers
+        | ZetaFormat::V0131GitMergeMarkersPrefix => String::new(), // empty: no prefill
+        ZetaFormat::V0211Prefill => v0211_prefill::get_prefill(input), // see `v0211_prefill`
+    }
+}
+
 fn format_edit_history_within_budget(events: &[Arc<Event>], max_tokens: usize) -> String {
     let header = "<|file_sep|>edit history\n";
     let header_tokens = estimate_tokens(header.len());
@@ -496,6 +513,41 @@ pub mod v0131_git_merge_markers_prefix {
     }
 }
 
+pub mod v0211_prefill { // prefill computation used for `ZetaFormat::V0211Prefill`
+    use super::*;
+    /// Returns the start of the editable region to use as prefill, cut at a newline or space.
+    pub fn get_prefill(input: &ZetaPromptInput) -> String {
+        let editable_region = &input.cursor_excerpt
+            [input.editable_range_in_excerpt.start..input.editable_range_in_excerpt.end];
+        // Byte budget: PREFILL_RATIO of the region (cast truncates), floored to a char boundary.
+        let prefill_len = (editable_region.len() as f64 * PREFILL_RATIO) as usize;
+        let prefill_len = editable_region.floor_char_boundary(prefill_len);
+
+        // Cut on a token boundary so the prefill never ends in the middle of a token.
+        // In Qwen2.5-Coder, \n is always the END of a token (e.g. `;\n`,
+        // ` {\n`), and \n\n / \n\n\n are single tokens, so we must include
+        // the last \n inside the budget plus every \n that directly follows it.
+        let prefill = &editable_region[..prefill_len];
+        match prefill.rfind('\n') {
+            Some(pos) => {
+                let mut end = pos + 1; // one past the last newline inside the budget
+                while end < editable_region.len() // scans the whole region: may pass prefill_len
+                    && editable_region.as_bytes().get(end) == Some(&b'\n')
+                {
+                    end += 1;
+                }
+                editable_region[..end].to_string()
+            }
+            // No newline inside the budget. Fall back to cutting before the last space
+            // (word-level boundary); the space itself is left out of the prefill.
+            None => match prefill.rfind(' ') {
+                Some(pos) => prefill[..pos].to_string(),
+                None => prefill.to_string(), // no newline or space: use the whole budget
+            },
+        }
+    }
+}
+
 /// The zeta1 prompt format
 pub mod zeta1 {
     pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";