diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index b3102455d7d4ac9640307ed706ca4cacc14d8592..cbad42e609388396ba3276e95d3f04a6b03e2929 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -118,6 +118,7 @@ pub fn compute_edits_and_cursor_position( // new_offset = old_offset + delta, so old_offset = new_offset - delta let mut delta: isize = 0; let mut cursor_position: Option = None; + let buffer_len = snapshot.len(); let edits = diffs .iter() @@ -129,13 +130,15 @@ pub fn compute_edits_and_cursor_position( if cursor_offset < edit_start_in_new { let cursor_in_old = (cursor_offset as isize - delta) as usize; + let buffer_offset = (offset + cursor_in_old).min(buffer_len); cursor_position = Some(PredictedCursorPosition::at_anchor( - snapshot.anchor_after(offset + cursor_in_old), + snapshot.anchor_after(buffer_offset), )); } else if cursor_offset < edit_end_in_new { + let buffer_offset = (offset + raw_old_range.start).min(buffer_len); let offset_within_insertion = cursor_offset - edit_start_in_new; cursor_position = Some(PredictedCursorPosition::new( - snapshot.anchor_before(offset + raw_old_range.start), + snapshot.anchor_before(buffer_offset), offset_within_insertion, )); } @@ -158,6 +161,9 @@ pub fn compute_edits_and_cursor_position( old_range.start += prefix_len; old_range.end -= suffix_len; + old_range.start = old_range.start.min(buffer_len); + old_range.end = old_range.end.min(buffer_len); + let new_text = new_text[prefix_len..new_text.len() - suffix_len].into(); let range = if old_range.is_empty() { let anchor = snapshot.anchor_after(old_range.start); diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 874644b7605776364b3455092443263de05d84cd..c9a7847704ae5dc35116079868870bbdf4ee0fdd 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -16,7 +16,7 @@ use std::env; use std::{path::Path, sync::Arc, time::Instant}; use zeta_prompt::{ CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output, - format_zeta_prompt, get_prefill, + format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens, }; pub const MAX_CONTEXT_TOKENS: usize = 350; @@ -85,6 +85,10 @@ pub fn request_prediction_with_zeta2( is_open_source, ); + if prompt_input_contains_special_tokens(&prompt_input, zeta_version) { + return Ok((None, None)); + } + if let Some(debug_tx) = &debug_tx { let prompt = format_zeta_prompt(&prompt_input, zeta_version); debug_tx diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index fa6f7ce8f03bf7a9534017b99f503ebd6041f827..53de7b387ff6a92801e4482eef809f44a23ff7fa 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -126,6 +126,25 @@ impl ZetaFormat { .collect::>() .concat() } + + pub fn special_tokens(&self) -> &'static [&'static str] { + match self { + ZetaFormat::V0112MiddleAtEnd + | ZetaFormat::V0113Ordered + | ZetaFormat::V0114180EditableRegion => &[ + "<|fim_prefix|>", + "<|fim_suffix|>", + "<|fim_middle|>", + "<|file_sep|>", + CURSOR_MARKER, + ], + ZetaFormat::V0120GitMergeMarkers => v0120_git_merge_markers::special_tokens(), + ZetaFormat::V0131GitMergeMarkersPrefix | ZetaFormat::V0211Prefill => { + v0131_git_merge_markers_prefix::special_tokens() + } + ZetaFormat::V0211SeedCoder => seed_coder::special_tokens(), + } + } } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -194,6 +213,13 @@ pub struct RelatedExcerpt { pub text: Arc, } +pub fn prompt_input_contains_special_tokens(input: &ZetaPromptInput, format: ZetaFormat) -> bool { + format + .special_tokens() + .iter() + .any(|token| input.cursor_excerpt.contains(token)) +} + pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String { format_zeta_prompt_with_budget(input, format, MAX_PROMPT_TOKENS) } @@ -560,6 +586,19 @@ pub mod v0120_git_merge_markers { pub const SEPARATOR: &str = "=======\n"; pub const END_MARKER: &str = ">>>>>>> UPDATED\n"; + pub fn special_tokens() -> &'static [&'static str] { + &[ + "<|fim_prefix|>", + "<|fim_suffix|>", + "<|fim_middle|>", + "<|file_sep|>", + START_MARKER, + SEPARATOR, + END_MARKER, + CURSOR_MARKER, + ] + } + pub fn write_cursor_excerpt_section( prompt: &mut String, path: &Path, @@ -621,6 +660,19 @@ pub mod v0131_git_merge_markers_prefix { pub const SEPARATOR: &str = "=======\n"; pub const END_MARKER: &str = ">>>>>>> UPDATED\n"; + pub fn special_tokens() -> &'static [&'static str] { + &[ + "<|fim_prefix|>", + "<|fim_suffix|>", + "<|fim_middle|>", + "<|file_sep|>", + START_MARKER, + SEPARATOR, + END_MARKER, + CURSOR_MARKER, + ] + } + pub fn write_cursor_excerpt_section( prompt: &mut String, path: &Path, @@ -738,6 +790,19 @@ pub mod seed_coder { pub const SEPARATOR: &str = "=======\n"; pub const END_MARKER: &str = ">>>>>>> UPDATED\n"; + pub fn special_tokens() -> &'static [&'static str] { + &[ + FIM_SUFFIX, + FIM_PREFIX, + FIM_MIDDLE, + FILE_MARKER, + START_MARKER, + SEPARATOR, + END_MARKER, + CURSOR_MARKER, + ] + } + pub fn format_prompt_with_budget( path: &Path, context: &str,