ep: Ensure prompt is always within token limit (#51529)

Ben Kunkle created

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/edit_prediction/src/edit_prediction_tests.rs |   1 
crates/edit_prediction/src/zeta.rs                  |  40 ++-
crates/edit_prediction_cli/src/format_prompt.rs     |   2 
crates/zeta_prompt/src/zeta_prompt.rs               | 154 ++++++++------
4 files changed, 113 insertions(+), 84 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -2270,6 +2270,7 @@ fn empty_response() -> PredictEditsV3Response {
 
 fn prompt_from_request(request: &PredictEditsV3Request) -> String {
     zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
+        .expect("default zeta prompt formatting should succeed in edit prediction tests")
 }
 
 fn assert_no_predict_request_ready(

crates/edit_prediction/src/zeta.rs 🔗

@@ -130,13 +130,14 @@ pub fn request_prediction_with_zeta(
                 return Err(anyhow::anyhow!("prompt contains special tokens"));
             }
 
+            let formatted_prompt = format_zeta_prompt(&prompt_input, zeta_version);
+
             if let Some(debug_tx) = &debug_tx {
-                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                 debug_tx
                     .unbounded_send(DebugEvent::EditPredictionStarted(
                         EditPredictionStartedDebugEvent {
                             buffer: buffer.downgrade(),
-                            prompt: Some(prompt),
+                            prompt: formatted_prompt.clone(),
                             position,
                         },
                     ))
@@ -145,11 +146,11 @@ pub fn request_prediction_with_zeta(
 
             log::trace!("Sending edit prediction request");
 
-            let (request_id, output, model_version, usage) =
-                if let Some(custom_settings) = &custom_server_settings {
+            let Some((request_id, output, model_version, usage)) =
+                (if let Some(custom_settings) = &custom_server_settings {
                     let max_tokens = custom_settings.max_output_tokens * 4;
 
-                    match custom_settings.prompt_format {
+                    Some(match custom_settings.prompt_format {
                         EditPredictionPromptFormat::Zeta => {
                             let ranges = &prompt_input.excerpt_ranges;
                             let editable_range_in_excerpt = ranges.editable_350.clone();
@@ -186,7 +187,9 @@ pub fn request_prediction_with_zeta(
                             (request_id, parsed_output, None, None)
                         }
                         EditPredictionPromptFormat::Zeta2 => {
-                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
+                            let Some(prompt) = formatted_prompt.clone() else {
+                                return Ok((None, None));
+                            };
                             let prefill = get_prefill(&prompt_input, zeta_version);
                             let prompt = format!("{prompt}{prefill}");
 
@@ -219,9 +222,11 @@ pub fn request_prediction_with_zeta(
                             (request_id, output_text, None, None)
                         }
                         _ => anyhow::bail!("unsupported prompt format"),
-                    }
+                    })
                 } else if let Some(config) = &raw_config {
-                    let prompt = format_zeta_prompt(&prompt_input, config.format);
+                    let Some(prompt) = format_zeta_prompt(&prompt_input, config.format) else {
+                        return Ok((None, None));
+                    };
                     let prefill = get_prefill(&prompt_input, config.format);
                     let prompt = format!("{prompt}{prefill}");
                     let environment = config
@@ -263,7 +268,7 @@ pub fn request_prediction_with_zeta(
                         None
                     };
 
-                    (request_id, output, None, usage)
+                    Some((request_id, output, None, usage))
                 } else {
                     // Use V3 endpoint - server handles model/version selection and suffix stripping
                     let (response, usage) = EditPredictionStore::send_v3_request(
@@ -284,8 +289,11 @@ pub fn request_prediction_with_zeta(
                         range_in_excerpt: response.editable_range,
                     };
 
-                    (request_id, Some(parsed_output), model_version, usage)
-                };
+                    Some((request_id, Some(parsed_output), model_version, usage))
+                })
+            else {
+                return Ok((None, None));
+            };
 
             let received_response_at = Instant::now();
 
@@ -296,7 +304,7 @@ pub fn request_prediction_with_zeta(
                 range_in_excerpt: editable_range_in_excerpt,
             }) = output
             else {
-                return Ok(((request_id, None), None));
+                return Ok((Some((request_id, None)), None));
             };
 
             let editable_range_in_buffer = editable_range_in_excerpt.start
@@ -342,7 +350,7 @@ pub fn request_prediction_with_zeta(
             );
 
             anyhow::Ok((
-                (
+                Some((
                     request_id,
                     Some(Prediction {
                         prompt_input,
@@ -354,14 +362,16 @@ pub fn request_prediction_with_zeta(
                         editable_range_in_buffer,
                         model_version,
                     }),
-                ),
+                )),
                 usage,
             ))
         }
     });
 
     cx.spawn(async move |this, cx| {
-        let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;
+        let Some((id, prediction)) = handle_api_response(&this, request_task.await, cx)? else {
+            return Ok(None);
+        };
 
         let Some(Prediction {
             prompt_input: inputs,

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -92,7 +92,7 @@ pub async fn run_format_prompt(
                 zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok()
             });
 
-            example.prompt = Some(ExamplePrompt {
+            example.prompt = prompt.map(|prompt| ExamplePrompt {
                 input: prompt,
                 expected_output,
                 rejected_output,

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -204,7 +204,7 @@ pub fn prompt_input_contains_special_tokens(input: &ZetaPromptInput, format: Zet
         .any(|token| input.cursor_excerpt.contains(token))
 }
 
-pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String {
+pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> Option<String> {
     format_prompt_with_budget_for_format(input, format, MAX_PROMPT_TOKENS)
 }
 
@@ -416,7 +416,7 @@ pub fn format_prompt_with_budget_for_format(
     input: &ZetaPromptInput,
     format: ZetaFormat,
     max_tokens: usize,
-) -> String {
+) -> Option<String> {
     let (context, editable_range, context_range, cursor_offset) =
         resolve_cursor_region(input, format);
     let path = &*input.cursor_path;
@@ -436,25 +436,24 @@ pub fn format_prompt_with_budget_for_format(
         input_related_files
     };
 
-    match format {
-        ZetaFormat::V0211SeedCoder | ZetaFormat::V0304SeedNoEdits => {
-            seed_coder::format_prompt_with_budget(
+    let prompt = match format {
+        ZetaFormat::V0211SeedCoder
+        | ZetaFormat::V0304SeedNoEdits
+        | ZetaFormat::V0306SeedMultiRegions => {
+            let mut cursor_section = String::new();
+            write_cursor_excerpt_section_for_format(
+                format,
+                &mut cursor_section,
                 path,
                 context,
                 &editable_range,
                 cursor_offset,
-                &input.events,
-                related_files,
-                max_tokens,
-            )
-        }
-        ZetaFormat::V0306SeedMultiRegions => {
-            let cursor_prefix =
-                build_v0306_cursor_prefix(path, context, &editable_range, cursor_offset);
+            );
+
             seed_coder::assemble_fim_prompt(
                 context,
                 &editable_range,
-                &cursor_prefix,
+                &cursor_section,
                 &input.events,
                 related_files,
                 max_tokens,
@@ -497,7 +496,12 @@ pub fn format_prompt_with_budget_for_format(
             prompt.push_str(&cursor_section);
             prompt
         }
+    };
+    let prompt_tokens = estimate_tokens(prompt.len());
+    if prompt_tokens > max_tokens {
+        return None;
     }
+    return Some(prompt);
 }
 
 pub fn filter_redundant_excerpts(
@@ -2707,8 +2711,8 @@ pub mod seed_coder {
     ) -> String {
         let suffix_section = build_suffix_section(context, editable_range);
 
-        let suffix_tokens = estimate_tokens(suffix_section.len());
-        let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len());
+        let suffix_tokens = estimate_tokens(suffix_section.len() + FIM_PREFIX.len());
+        let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len() + FIM_MIDDLE.len());
         let budget_after_cursor = max_tokens.saturating_sub(suffix_tokens + cursor_prefix_tokens);
 
         let edit_history_section = super::format_edit_history_within_budget(
@@ -2718,8 +2722,9 @@ pub mod seed_coder {
             budget_after_cursor,
             max_edit_event_count_for_format(&ZetaFormat::V0211SeedCoder),
         );
-        let edit_history_tokens = estimate_tokens(edit_history_section.len());
-        let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
+        let edit_history_tokens = estimate_tokens(edit_history_section.len() + "\n".len());
+        let budget_after_edit_history =
+            budget_after_cursor.saturating_sub(edit_history_tokens + "\n".len());
 
         let related_files_section = super::format_related_files_within_budget(
             related_files,
@@ -2741,6 +2746,7 @@ pub mod seed_coder {
         }
         prompt.push_str(cursor_prefix_section);
         prompt.push_str(FIM_MIDDLE);
+
         prompt
     }
 
@@ -4087,7 +4093,7 @@ mod tests {
         }
     }
 
-    fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
+    fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> Option<String> {
         format_prompt_with_budget_for_format(input, ZetaFormat::V0114180EditableRegion, max_tokens)
     }
 
@@ -4102,7 +4108,7 @@ mod tests {
         );
 
         assert_eq!(
-            format_with_budget(&input, 10000),
+            format_with_budget(&input, 10000).unwrap(),
             indoc! {r#"
                 <|file_sep|>related.rs
                 fn helper() {}
@@ -4121,6 +4127,7 @@ mod tests {
                 suffix
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
     }
 
@@ -4132,18 +4139,18 @@ mod tests {
             2,
             vec![make_event("a.rs", "-x\n+y\n")],
             vec![
-                make_related_file("r1.rs", "a\n"),
-                make_related_file("r2.rs", "b\n"),
+                make_related_file("r1.rs", "aaaaaaa\n"),
+                make_related_file("r2.rs", "bbbbbbb\n"),
             ],
         );
 
         assert_eq!(
-            format_with_budget(&input, 10000),
+            format_with_budget(&input, 10000).unwrap(),
             indoc! {r#"
                 <|file_sep|>r1.rs
-                a
+                aaaaaaa
                 <|file_sep|>r2.rs
-                b
+                bbbbbbb
                 <|file_sep|>edit history
                 --- a/a.rs
                 +++ b/a.rs
@@ -4156,15 +4163,18 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
 
         assert_eq!(
-            format_with_budget(&input, 50),
-            indoc! {r#"
-                <|file_sep|>r1.rs
-                a
-                <|file_sep|>r2.rs
-                b
+            format_with_budget(&input, 55),
+            Some(
+                indoc! {r#"
+                <|file_sep|>edit history
+                --- a/a.rs
+                +++ b/a.rs
+                -x
+                +y
                 <|file_sep|>test.rs
                 <|fim_prefix|>
                 <|fim_middle|>current
@@ -4172,6 +4182,8 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+                .to_string()
+            )
         );
     }
 
@@ -4207,7 +4219,7 @@ mod tests {
         );
 
         assert_eq!(
-            format_with_budget(&input, 10000),
+            format_with_budget(&input, 10000).unwrap(),
             indoc! {r#"
                 <|file_sep|>big.rs
                 first excerpt
@@ -4222,10 +4234,11 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
 
         assert_eq!(
-            format_with_budget(&input, 50),
+            format_with_budget(&input, 50).unwrap(),
             indoc! {r#"
                 <|file_sep|>big.rs
                 first excerpt
@@ -4237,6 +4250,7 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
     }
 
@@ -4275,7 +4289,7 @@ mod tests {
 
         // With large budget, both files included; rendered in stable lexicographic order.
         assert_eq!(
-            format_with_budget(&input, 10000),
+            format_with_budget(&input, 10000).unwrap(),
             indoc! {r#"
                 <|file_sep|>file_a.rs
                 low priority content
@@ -4288,6 +4302,7 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
 
         // With tight budget, only file_b (lower order) fits.
@@ -4295,7 +4310,7 @@ mod tests {
         // file_b header (7) + excerpt (7) = 14 tokens, which fits.
         // file_a would need another 14 tokens, which doesn't fit.
         assert_eq!(
-            format_with_budget(&input, 52),
+            format_with_budget(&input, 52).unwrap(),
             indoc! {r#"
                 <|file_sep|>file_b.rs
                 high priority content
@@ -4306,6 +4321,7 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
     }
 
@@ -4347,7 +4363,7 @@ mod tests {
 
         // With large budget, all three excerpts included.
         assert_eq!(
-            format_with_budget(&input, 10000),
+            format_with_budget(&input, 10000).unwrap(),
             indoc! {r#"
                 <|file_sep|>mod.rs
                 mod header
@@ -4362,11 +4378,12 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
 
         // With tight budget, only order<=1 excerpts included (header + important fn).
         assert_eq!(
-            format_with_budget(&input, 55),
+            format_with_budget(&input, 55).unwrap(),
             indoc! {r#"
                 <|file_sep|>mod.rs
                 mod header
@@ -4380,6 +4397,7 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
     }
 
@@ -4394,7 +4412,7 @@ mod tests {
         );
 
         assert_eq!(
-            format_with_budget(&input, 10000),
+            format_with_budget(&input, 10000).unwrap(),
             indoc! {r#"
                 <|file_sep|>edit history
                 --- a/old.rs
@@ -4410,10 +4428,11 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
 
         assert_eq!(
-            format_with_budget(&input, 55),
+            format_with_budget(&input, 60).unwrap(),
             indoc! {r#"
                 <|file_sep|>edit history
                 --- a/new.rs
@@ -4426,6 +4445,7 @@ mod tests {
                 <|fim_suffix|>
                 <|fim_middle|>updated
             "#}
+            .to_string()
         );
     }
 
@@ -4439,25 +4459,19 @@ mod tests {
             vec![make_related_file("related.rs", "helper\n")],
         );
 
-        assert_eq!(
-            format_with_budget(&input, 30),
-            indoc! {r#"
-                <|file_sep|>test.rs
-                <|fim_prefix|>
-                <|fim_middle|>current
-                fn <|user_cursor|>main() {}
-                <|fim_suffix|>
-                <|fim_middle|>updated
-            "#}
-        );
+        assert!(format_with_budget(&input, 30).is_none())
     }
 
+    #[track_caller]
     fn format_seed_coder(input: &ZetaPromptInput) -> String {
         format_prompt_with_budget_for_format(input, ZetaFormat::V0211SeedCoder, 10000)
+            .expect("seed coder prompt formatting should succeed")
     }
 
+    #[track_caller]
     fn format_seed_coder_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
         format_prompt_with_budget_for_format(input, ZetaFormat::V0211SeedCoder, max_tokens)
+            .expect("seed coder prompt formatting should succeed")
     }
 
     #[test]
@@ -4542,17 +4556,22 @@ mod tests {
                 <[fim-middle]>"#}
         );
 
-        // With tight budget, context is dropped but cursor section remains
         assert_eq!(
-            format_seed_coder_with_budget(&input, 30),
+            format_prompt_with_budget_for_format(&input, ZetaFormat::V0211SeedCoder, 24),
+            None
+        );
+
+        assert_eq!(
+            format_seed_coder_with_budget(&input, 40),
             indoc! {r#"
                 <[fim-suffix]>
                 <[fim-prefix]><filename>test.rs
                 <<<<<<< CURRENT
                 co<|user_cursor|>de
                 =======
-                <[fim-middle]>"#}
-        );
+                <[fim-middle]>"#
+            }
+        )
     }
 
     #[test]
@@ -4603,21 +4622,20 @@ mod tests {
                 <[fim-middle]>"#}
         );
 
-        // With tight budget, only high_prio included.
-        // Cursor sections cost 25 tokens, so budget 44 leaves 19 for related files.
-        // high_prio header (7) + excerpt (3) = 10, fits. low_prio would add 10 more = 20 > 19.
+        // With tight budget under the generic heuristic, context is dropped but the
+        // minimal cursor section still fits.
         assert_eq!(
-            format_seed_coder_with_budget(&input, 44),
-            indoc! {r#"
-                <[fim-suffix]>
-                <[fim-prefix]><filename>high_prio.rs
-                high prio
-
-                <filename>test.rs
-                <<<<<<< CURRENT
-                co<|user_cursor|>de
-                =======
-                <[fim-middle]>"#}
+            format_prompt_with_budget_for_format(&input, ZetaFormat::V0211SeedCoder, 44),
+            Some(
+                indoc! {r#"
+                    <[fim-suffix]>
+                    <[fim-prefix]><filename>test.rs
+                    <<<<<<< CURRENT
+                    co<|user_cursor|>de
+                    =======
+                    <[fim-middle]>"#}
+                .to_string()
+            )
         );
     }