From 8a7daea560ac589e8627e61998af239943b2f7f5 Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Sat, 14 Mar 2026 12:53:28 -0500 Subject: [PATCH] ep: Ensure prompt is always within token limit (#51529) Release Notes: - N/A *or* Added/Fixed/Improved ... --- .../src/edit_prediction_tests.rs | 1 + crates/edit_prediction/src/zeta.rs | 40 +++-- .../edit_prediction_cli/src/format_prompt.rs | 2 +- crates/zeta_prompt/src/zeta_prompt.rs | 154 ++++++++++-------- 4 files changed, 113 insertions(+), 84 deletions(-) diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index dc52ef6ab57428d6293cea126c695f7c659e2f53..74688f64effc4c4e371d4516b25c6ce55b317dbb 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -2270,6 +2270,7 @@ fn empty_response() -> PredictEditsV3Response { fn prompt_from_request(request: &PredictEditsV3Request) -> String { zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default()) + .expect("default zeta prompt formatting should succeed in edit prediction tests") } fn assert_no_predict_request_ready( diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index fa93e681b66cb44a554f725d4a1c6dee11f0b1f1..fc3ed81c78737f4ba4c8b7aa5131232b2b007b87 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -130,13 +130,14 @@ pub fn request_prediction_with_zeta( return Err(anyhow::anyhow!("prompt contains special tokens")); } + let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version); + if let Some(debug_tx) = &debug_tx { - let prompt = format_zeta_prompt(&prompt_input, zeta_version); debug_tx .unbounded_send(DebugEvent::EditPredictionStarted( EditPredictionStartedDebugEvent { buffer: buffer.downgrade(), - prompt: Some(prompt), + prompt: formatted_prompt.clone(), position, }, )) @@ -145,11 +146,11 @@ pub fn 
request_prediction_with_zeta( log::trace!("Sending edit prediction request"); - let (request_id, output, model_version, usage) = - if let Some(custom_settings) = &custom_server_settings { + let Some((request_id, output, model_version, usage)) = + (if let Some(custom_settings) = &custom_server_settings { let max_tokens = custom_settings.max_output_tokens * 4; - match custom_settings.prompt_format { + Some(match custom_settings.prompt_format { EditPredictionPromptFormat::Zeta => { let ranges = &prompt_input.excerpt_ranges; let editable_range_in_excerpt = ranges.editable_350.clone(); @@ -186,7 +187,9 @@ pub fn request_prediction_with_zeta( (request_id, parsed_output, None, None) } EditPredictionPromptFormat::Zeta2 => { - let prompt = format_zeta_prompt(&prompt_input, zeta_version); + let Some(prompt) = formatted_prompt.clone() else { + return Ok((None, None)); + }; let prefill = get_prefill(&prompt_input, zeta_version); let prompt = format!("{prompt}{prefill}"); @@ -219,9 +222,11 @@ pub fn request_prediction_with_zeta( (request_id, output_text, None, None) } _ => anyhow::bail!("unsupported prompt format"), - } + }) } else if let Some(config) = &raw_config { - let prompt = format_zeta_prompt(&prompt_input, config.format); + let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else { + return Ok((None, None)); + }; let prefill = get_prefill(&prompt_input, config.format); let prompt = format!("{prompt}{prefill}"); let environment = config @@ -263,7 +268,7 @@ pub fn request_prediction_with_zeta( None }; - (request_id, output, None, usage) + Some((request_id, output, None, usage)) } else { // Use V3 endpoint - server handles model/version selection and suffix stripping let (response, usage) = EditPredictionStore::send_v3_request( @@ -284,8 +289,11 @@ pub fn request_prediction_with_zeta( range_in_excerpt: response.editable_range, }; - (request_id, Some(parsed_output), model_version, usage) - }; + Some((request_id, Some(parsed_output), model_version, usage)) + 
}) + else { + return Ok((None, None)); + }; let received_response_at = Instant::now(); @@ -296,7 +304,7 @@ pub fn request_prediction_with_zeta( range_in_excerpt: editable_range_in_excerpt, }) = output else { - return Ok(((request_id, None), None)); + return Ok((Some((request_id, None)), None)); }; let editable_range_in_buffer = editable_range_in_excerpt.start @@ -342,7 +350,7 @@ pub fn request_prediction_with_zeta( ); anyhow::Ok(( - ( + Some(( request_id, Some(Prediction { prompt_input, @@ -354,14 +362,16 @@ pub fn request_prediction_with_zeta( editable_range_in_buffer, model_version, }), - ), + )), usage, )) } }); cx.spawn(async move |this, cx| { - let (id, prediction) = handle_api_response(&this, request_task.await, cx)?; + let Some((id, prediction)) = handle_api_response(&this, request_task.await, cx)? else { + return Ok(None); + }; let Some(Prediction { prompt_input: inputs, diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index af955a05dce01fd34c37eb55d15b76b4a4592745..3a20fe0e9a5f89fa3325c1972721a836d60f7156 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -92,7 +92,7 @@ pub async fn run_format_prompt( zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok() }); - example.prompt = Some(ExamplePrompt { + example.prompt = prompt.map(|prompt| ExamplePrompt { input: prompt, expected_output, rejected_output, diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index d79ded2b9781252855ef424e49247fc1cabd383f..0dce7764e7b9c451b4360fb2177d9d3e0eb7315b 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -204,7 +204,7 @@ pub fn prompt_input_contains_special_tokens(input: &ZetaPromptInput, format: Zet .any(|token| input.cursor_excerpt.contains(token)) } -pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String { +pub fn 
format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> Option<String> { format_prompt_with_budget_for_format(input, format, MAX_PROMPT_TOKENS) } @@ -416,7 +416,7 @@ pub fn format_prompt_with_budget_for_format( input: &ZetaPromptInput, format: ZetaFormat, max_tokens: usize, -) -> String { +) -> Option<String> { let (context, editable_range, context_range, cursor_offset) = resolve_cursor_region(input, format); let path = &*input.cursor_path; @@ -436,25 +436,24 @@ pub fn format_prompt_with_budget_for_format( input_related_files }; - match format { - ZetaFormat::V0211SeedCoder | ZetaFormat::V0304SeedNoEdits => { - seed_coder::format_prompt_with_budget( + let prompt = match format { + ZetaFormat::V0211SeedCoder + | ZetaFormat::V0304SeedNoEdits + | ZetaFormat::V0306SeedMultiRegions => { + let mut cursor_section = String::new(); + write_cursor_excerpt_section_for_format( + format, + &mut cursor_section, path, context, &editable_range, cursor_offset, - &input.events, - related_files, - max_tokens, - ) - } - ZetaFormat::V0306SeedMultiRegions => { - let cursor_prefix = - build_v0306_cursor_prefix(path, context, &editable_range, cursor_offset); + ); + seed_coder::assemble_fim_prompt( context, &editable_range, - &cursor_prefix, + &cursor_section, &input.events, related_files, max_tokens, @@ -497,7 +496,12 @@ pub fn format_prompt_with_budget_for_format( prompt.push_str(&cursor_section); prompt } + }; + let prompt_tokens = estimate_tokens(prompt.len()); + if prompt_tokens > max_tokens { + return None; } + return Some(prompt); } pub fn filter_redundant_excerpts( @@ -2707,8 +2711,8 @@ pub mod seed_coder { ) -> String { let suffix_section = build_suffix_section(context, editable_range); - let suffix_tokens = estimate_tokens(suffix_section.len()); - let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len()); + let suffix_tokens = estimate_tokens(suffix_section.len() + FIM_PREFIX.len()); + let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len() + 
FIM_MIDDLE.len()); let budget_after_cursor = max_tokens.saturating_sub(suffix_tokens + cursor_prefix_tokens); let edit_history_section = super::format_edit_history_within_budget( @@ -2718,8 +2722,9 @@ pub mod seed_coder { budget_after_cursor, max_edit_event_count_for_format(&ZetaFormat::V0211SeedCoder), ); - let edit_history_tokens = estimate_tokens(edit_history_section.len()); - let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens); + let edit_history_tokens = estimate_tokens(edit_history_section.len() + "\n".len()); + let budget_after_edit_history = + budget_after_cursor.saturating_sub(edit_history_tokens + "\n".len()); let related_files_section = super::format_related_files_within_budget( related_files, @@ -2741,6 +2746,7 @@ pub mod seed_coder { } prompt.push_str(cursor_prefix_section); prompt.push_str(FIM_MIDDLE); + prompt } @@ -4087,7 +4093,7 @@ mod tests { } } - fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String { + fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> Option<String> { format_prompt_with_budget_for_format(input, ZetaFormat::V0114180EditableRegion, max_tokens) } @@ -4102,7 +4108,7 @@ mod tests { ); assert_eq!( - format_with_budget(&input, 10000), + format_with_budget(&input, 10000).unwrap(), indoc! {r#" <|file_sep|>related.rs fn helper() {} @@ -4121,6 +4127,7 @@ mod tests { suffix <|fim_middle|>updated "#} + .to_string() ); } @@ -4132,18 +4139,18 @@ mod tests { 2, vec![make_event("a.rs", "-x\n+y\n")], vec![ - make_related_file("r1.rs", "a\n"), - make_related_file("r2.rs", "b\n"), + make_related_file("r1.rs", "aaaaaaa\n"), + make_related_file("r2.rs", "bbbbbbb\n"), ], ); assert_eq!( - format_with_budget(&input, 10000), + format_with_budget(&input, 10000).unwrap(), indoc! 
{r#" <|file_sep|>r1.rs - a + aaaaaaa <|file_sep|>r2.rs - b + bbbbbbb <|file_sep|>edit history --- a/a.rs +++ b/a.rs @@ -4156,15 +4163,18 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() ); assert_eq!( - format_with_budget(&input, 50), - indoc! {r#" - <|file_sep|>r1.rs - a - <|file_sep|>r2.rs - b + format_with_budget(&input, 55), + Some( + indoc! {r#" + <|file_sep|>edit history + --- a/a.rs + +++ b/a.rs + -x + +y <|file_sep|>test.rs <|fim_prefix|> <|fim_middle|>current @@ -4172,6 +4182,8 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() + ) ); } @@ -4207,7 +4219,7 @@ mod tests { ); assert_eq!( - format_with_budget(&input, 10000), + format_with_budget(&input, 10000).unwrap(), indoc! {r#" <|file_sep|>big.rs first excerpt @@ -4222,10 +4234,11 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() ); assert_eq!( - format_with_budget(&input, 50), + format_with_budget(&input, 50).unwrap(), indoc! {r#" <|file_sep|>big.rs first excerpt @@ -4237,6 +4250,7 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() ); } @@ -4275,7 +4289,7 @@ mod tests { // With large budget, both files included; rendered in stable lexicographic order. assert_eq!( - format_with_budget(&input, 10000), + format_with_budget(&input, 10000).unwrap(), indoc! {r#" <|file_sep|>file_a.rs low priority content @@ -4288,6 +4302,7 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() ); // With tight budget, only file_b (lower order) fits. @@ -4295,7 +4310,7 @@ mod tests { // file_b header (7) + excerpt (7) = 14 tokens, which fits. // file_a would need another 14 tokens, which doesn't fit. assert_eq!( - format_with_budget(&input, 52), + format_with_budget(&input, 52).unwrap(), indoc! {r#" <|file_sep|>file_b.rs high priority content @@ -4306,6 +4321,7 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() ); } @@ -4347,7 +4363,7 @@ mod tests { // With large budget, all three excerpts included. 
assert_eq!( - format_with_budget(&input, 10000), + format_with_budget(&input, 10000).unwrap(), indoc! {r#" <|file_sep|>mod.rs mod header @@ -4362,11 +4378,12 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() ); // With tight budget, only order<=1 excerpts included (header + important fn). assert_eq!( - format_with_budget(&input, 55), + format_with_budget(&input, 55).unwrap(), indoc! {r#" <|file_sep|>mod.rs mod header @@ -4380,6 +4397,7 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() ); } @@ -4394,7 +4412,7 @@ mod tests { ); assert_eq!( - format_with_budget(&input, 10000), + format_with_budget(&input, 10000).unwrap(), indoc! {r#" <|file_sep|>edit history --- a/old.rs @@ -4410,10 +4428,11 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() ); assert_eq!( - format_with_budget(&input, 55), + format_with_budget(&input, 60).unwrap(), indoc! {r#" <|file_sep|>edit history --- a/new.rs @@ -4426,6 +4445,7 @@ mod tests { <|fim_suffix|> <|fim_middle|>updated "#} + .to_string() ); } @@ -4439,25 +4459,19 @@ mod tests { vec![make_related_file("related.rs", "helper\n")], ); - assert_eq!( - format_with_budget(&input, 30), - indoc! 
{r#" - <|file_sep|>test.rs - <|fim_prefix|> - <|fim_middle|>current - fn <|user_cursor|>main() {} - <|fim_suffix|> - <|fim_middle|>updated - "#} - ); + assert!(format_with_budget(&input, 30).is_none()) } + #[track_caller] fn format_seed_coder(input: &ZetaPromptInput) -> String { format_prompt_with_budget_for_format(input, ZetaFormat::V0211SeedCoder, 10000) + .expect("seed coder prompt formatting should succeed") } + #[track_caller] fn format_seed_coder_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String { format_prompt_with_budget_for_format(input, ZetaFormat::V0211SeedCoder, max_tokens) + .expect("seed coder prompt formatting should succeed") } #[test] @@ -4542,17 +4556,22 @@ mod tests { <[fim-middle]>"#} ); - // With tight budget, context is dropped but cursor section remains assert_eq!( - format_seed_coder_with_budget(&input, 30), + format_prompt_with_budget_for_format(&input, ZetaFormat::V0211SeedCoder, 24), + None + ); + + assert_eq!( + format_seed_coder_with_budget(&input, 40), indoc! {r#" <[fim-suffix]> <[fim-prefix]>test.rs <<<<<<< CURRENT co<|user_cursor|>de ======= - <[fim-middle]>"#} - ); + <[fim-middle]>"# + } + ) } #[test] @@ -4603,21 +4622,20 @@ mod tests { <[fim-middle]>"#} ); - // With tight budget, only high_prio included. - // Cursor sections cost 25 tokens, so budget 44 leaves 19 for related files. - // high_prio header (7) + excerpt (3) = 10, fits. low_prio would add 10 more = 20 > 19. + // With tight budget under the generic heuristic, context is dropped but the + // minimal cursor section still fits. assert_eq!( - format_seed_coder_with_budget(&input, 44), - indoc! {r#" - <[fim-suffix]> - <[fim-prefix]>high_prio.rs - high prio - - test.rs - <<<<<<< CURRENT - co<|user_cursor|>de - ======= - <[fim-middle]>"#} + format_prompt_with_budget_for_format(&input, ZetaFormat::V0211SeedCoder, 44), + Some( + indoc! 
{r#" + <[fim-suffix]> + <[fim-prefix]>test.rs + <<<<<<< CURRENT + co<|user_cursor|>de + ======= + <[fim-middle]>"#} + .to_string() + ) ); }