Detailed changes
@@ -13,7 +13,8 @@ use release_channel::AppVersion;
use std::env;
use std::{path::Path, sync::Arc, time::Instant};
-use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt};
+use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output};
+use zeta_prompt::{format_zeta_prompt, get_prefill};
pub const MAX_CONTEXT_TOKENS: usize = 350;
@@ -23,6 +24,7 @@ pub fn max_editable_tokens(format: ZetaFormat) -> usize {
ZetaFormat::V0114180EditableRegion => 180,
ZetaFormat::V0120GitMergeMarkers => 180,
ZetaFormat::V0131GitMergeMarkersPrefix => 180,
+ ZetaFormat::V0211Prefill => 180,
}
}
@@ -88,6 +90,8 @@ pub fn request_prediction_with_zeta2(
let (request_id, output_text, usage) = if let Some(config) = &raw_config {
let prompt = format_zeta_prompt(&prompt_input, config.format);
+ let prefill = get_prefill(&prompt_input, config.format);
+ let prompt = format!("{prompt}{prefill}");
let request = RawCompletionRequest {
model: config.model_id.clone().unwrap_or_default(),
prompt,
@@ -108,7 +112,9 @@ pub fn request_prediction_with_zeta2(
let request_id = EditPredictionId(response.id.clone().into());
let output_text = response.choices.pop().map(|choice| {
- clean_zeta2_model_output(&choice.text, config.format).to_string()
+ let response = &choice.text;
+ let output = format!("{prefill}{response}");
+ clean_zeta2_model_output(&output, config.format).to_string()
});
(request_id, output_text, usage)
@@ -76,6 +76,8 @@ pub struct ExamplePrompt {
pub input: String,
pub expected_output: String,
pub rejected_output: Option<String>, // For DPO
+ #[serde(default)]
+ pub prefill: Option<String>,
pub provider: PredictionProvider,
}
@@ -65,6 +65,7 @@ pub async fn run_format_prompt(
input: prompt,
expected_output: String::new(),
rejected_output: None,
+ prefill: None,
provider: args.provider,
});
}
@@ -94,6 +95,7 @@ pub async fn run_format_prompt(
related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
};
let prompt = format_zeta_prompt(&input, version);
+ let prefill = zeta_prompt::get_prefill(&input, version);
let (expected_patch, expected_cursor_offset) = example
.spec
.expected_patches_with_cursor_positions()
@@ -113,6 +115,7 @@ pub async fn run_format_prompt(
expected_output,
rejected_output,
provider: args.provider,
+ prefill: Some(prefill),
});
}
_ => {
@@ -55,7 +55,9 @@ fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<Stri
ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
("<|fim_middle|>current\n", "<|fim_suffix|>")
}
- ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => (
+ ZetaFormat::V0120GitMergeMarkers
+ | ZetaFormat::V0131GitMergeMarkersPrefix
+ | ZetaFormat::V0211Prefill => (
zeta_prompt::v0120_git_merge_markers::START_MARKER,
zeta_prompt::v0120_git_merge_markers::SEPARATOR,
),
@@ -101,11 +103,13 @@ fn parse_zeta2_output(
};
let suffix = match format {
- ZetaFormat::V0131GitMergeMarkersPrefix => {
+ ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
}
ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
- _ => "",
+ ZetaFormat::V0112MiddleAtEnd
+ | ZetaFormat::V0113Ordered
+ | ZetaFormat::V0114180EditableRegion => "",
};
if !suffix.is_empty() {
new_text = new_text
@@ -159,6 +159,7 @@ pub async fn run_prediction(
expected_output: String::new(),
rejected_output: None,
provider,
+ prefill: None,
});
}
}
@@ -9,6 +9,11 @@ use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
pub const MAX_PROMPT_TOKENS: usize = 4096;
+/// Use up to this amount of the editable region for prefill.
+/// Larger values may result in more robust generation, but
+/// this region becomes non-editable.
+pub const PREFILL_RATIO: f64 = 0.1; // 10%
+
fn estimate_tokens(bytes: usize) -> usize {
bytes / 3
}
@@ -46,6 +51,7 @@ pub enum ZetaFormat {
V0114180EditableRegion,
V0120GitMergeMarkers,
V0131GitMergeMarkersPrefix,
+ V0211Prefill,
}
impl std::fmt::Display for ZetaFormat {
@@ -170,7 +176,7 @@ fn format_zeta_prompt_with_budget(
ZetaFormat::V0120GitMergeMarkers => {
v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
}
- ZetaFormat::V0131GitMergeMarkersPrefix => {
+ ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
v0131_git_merge_markers_prefix::write_cursor_excerpt_section(&mut cursor_section, input)
}
}
@@ -193,6 +199,17 @@ fn format_zeta_prompt_with_budget(
prompt
}
+pub fn get_prefill(input: &ZetaPromptInput, format: ZetaFormat) -> String {
+ match format {
+ ZetaFormat::V0112MiddleAtEnd
+ | ZetaFormat::V0113Ordered
+ | ZetaFormat::V0114180EditableRegion
+ | ZetaFormat::V0120GitMergeMarkers
+ | ZetaFormat::V0131GitMergeMarkersPrefix => String::new(),
+ ZetaFormat::V0211Prefill => v0211_prefill::get_prefill(input),
+ }
+}
+
fn format_edit_history_within_budget(events: &[Arc<Event>], max_tokens: usize) -> String {
let header = "<|file_sep|>edit history\n";
let header_tokens = estimate_tokens(header.len());
@@ -496,6 +513,41 @@ pub mod v0131_git_merge_markers_prefix {
}
}
+pub mod v0211_prefill {
+ use super::*;
+
+ pub fn get_prefill(input: &ZetaPromptInput) -> String {
+ let editable_region = &input.cursor_excerpt
+ [input.editable_range_in_excerpt.start..input.editable_range_in_excerpt.end];
+
+ let prefill_len = (editable_region.len() as f64 * PREFILL_RATIO) as usize;
+ let prefill_len = editable_region.floor_char_boundary(prefill_len);
+
+ // Find a token boundary to avoid splitting tokens in the prefill.
+ // In Qwen2.5-Coder, \n is always the END of a token (e.g. `;\n`,
+ // ` {\n`), and \n\n / \n\n\n are single tokens, so we must include
+ // the \n and consume any consecutive \n characters after it.
+ let prefill = &editable_region[..prefill_len];
+ match prefill.rfind('\n') {
+ Some(pos) => {
+ let mut end = pos + 1;
+ while end < editable_region.len()
+ && editable_region.as_bytes().get(end) == Some(&b'\n')
+ {
+ end += 1;
+ }
+ editable_region[..end].to_string()
+ }
+ // No newline found. Fall back to splitting before the last space
+ // (word-level boundary)
+ None => match prefill.rfind(' ') {
+ Some(pos) => prefill[..pos].to_string(),
+ None => prefill.to_string(),
+ },
+ }
+ }
+}
+
/// The zeta1 prompt format
pub mod zeta1 {
pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";