ep: Don't run predictions for excerpts with special tokens (#49040)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/zeta1.rs   | 10 +++
crates/edit_prediction/src/zeta2.rs   |  6 ++
crates/zeta_prompt/src/zeta_prompt.rs | 65 +++++++++++++++++++++++++++++
3 files changed, 78 insertions(+), 3 deletions(-)

Detailed changes

crates/edit_prediction/src/zeta1.rs 🔗

@@ -118,6 +118,7 @@ pub fn compute_edits_and_cursor_position(
     // new_offset = old_offset + delta, so old_offset = new_offset - delta
     let mut delta: isize = 0;
     let mut cursor_position: Option<PredictedCursorPosition> = None;
+    let buffer_len = snapshot.len();
 
     let edits = diffs
         .iter()
@@ -129,13 +130,15 @@ pub fn compute_edits_and_cursor_position(
 
                 if cursor_offset < edit_start_in_new {
                     let cursor_in_old = (cursor_offset as isize - delta) as usize;
+                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                     cursor_position = Some(PredictedCursorPosition::at_anchor(
-                        snapshot.anchor_after(offset + cursor_in_old),
+                        snapshot.anchor_after(buffer_offset),
                     ));
                 } else if cursor_offset < edit_end_in_new {
+                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                     let offset_within_insertion = cursor_offset - edit_start_in_new;
                     cursor_position = Some(PredictedCursorPosition::new(
-                        snapshot.anchor_before(offset + raw_old_range.start),
+                        snapshot.anchor_before(buffer_offset),
                         offset_within_insertion,
                     ));
                 }
@@ -158,6 +161,9 @@ pub fn compute_edits_and_cursor_position(
             old_range.start += prefix_len;
             old_range.end -= suffix_len;
 
+            old_range.start = old_range.start.min(buffer_len);
+            old_range.end = old_range.end.min(buffer_len);
+
             let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
             let range = if old_range.is_empty() {
                 let anchor = snapshot.anchor_after(old_range.start);

crates/edit_prediction/src/zeta2.rs 🔗

@@ -16,7 +16,7 @@ use std::env;
 use std::{path::Path, sync::Arc, time::Instant};
 use zeta_prompt::{
     CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
-    format_zeta_prompt, get_prefill,
+    format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens,
 };
 
 pub const MAX_CONTEXT_TOKENS: usize = 350;
@@ -85,6 +85,10 @@ pub fn request_prediction_with_zeta2(
                 is_open_source,
             );
 
+            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
+                return Ok((None, None));
+            }
+
             if let Some(debug_tx) = &debug_tx {
                 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                 debug_tx

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -126,6 +126,25 @@ impl ZetaFormat {
             .collect::<Vec<_>>()
             .concat()
     }
+
+    pub fn special_tokens(&self) -> &'static [&'static str] {
+        match self {
+            ZetaFormat::V0112MiddleAtEnd
+            | ZetaFormat::V0113Ordered
+            | ZetaFormat::V0114180EditableRegion => &[
+                "<|fim_prefix|>",
+                "<|fim_suffix|>",
+                "<|fim_middle|>",
+                "<|file_sep|>",
+                CURSOR_MARKER,
+            ],
+            ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::special_tokens(),
+            ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => {
+                v0131_git_merge_markers_prefix::special_tokens()
+            }
+            ZetaFormat::V0211SeedCoder => seed_coder::special_tokens(),
+        }
+    }
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
@@ -194,6 +213,13 @@ pub struct RelatedExcerpt {
     pub text: Arc<str>,
 }
 
+pub fn prompt_input_contains_special_tokens(input: &ZetaPromptInput, format: ZetaFormat) -> bool {
+    format
+        .special_tokens()
+        .iter()
+        .any(|token| input.cursor_excerpt.contains(token))
+}
+
 pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String {
     format_zeta_prompt_with_budget(input, format, MAX_PROMPT_TOKENS)
 }
@@ -560,6 +586,19 @@ pub mod v0120_git_merge_markers {
     pub const SEPARATOR: &str = "=======\n";
     pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
 
+    pub fn special_tokens() -> &'static [&'static str] {
+        &[
+            "<|fim_prefix|>",
+            "<|fim_suffix|>",
+            "<|fim_middle|>",
+            "<|file_sep|>",
+            START_MARKER,
+            SEPARATOR,
+            END_MARKER,
+            CURSOR_MARKER,
+        ]
+    }
+
     pub fn write_cursor_excerpt_section(
         prompt: &mut String,
         path: &Path,
@@ -621,6 +660,19 @@ pub mod v0131_git_merge_markers_prefix {
     pub const SEPARATOR: &str = "=======\n";
     pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
 
+    pub fn special_tokens() -> &'static [&'static str] {
+        &[
+            "<|fim_prefix|>",
+            "<|fim_suffix|>",
+            "<|fim_middle|>",
+            "<|file_sep|>",
+            START_MARKER,
+            SEPARATOR,
+            END_MARKER,
+            CURSOR_MARKER,
+        ]
+    }
+
     pub fn write_cursor_excerpt_section(
         prompt: &mut String,
         path: &Path,
@@ -738,6 +790,19 @@ pub mod seed_coder {
     pub const SEPARATOR: &str = "=======\n";
     pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
 
+    pub fn special_tokens() -> &'static [&'static str] {
+        &[
+            FIM_SUFFIX,
+            FIM_PREFIX,
+            FIM_MIDDLE,
+            FILE_MARKER,
+            START_MARKER,
+            SEPARATOR,
+            END_MARKER,
+            CURSOR_MARKER,
+        ]
+    }
+
     pub fn format_prompt_with_budget(
         path: &Path,
         context: &str,