zeta2: Cut oldest events to maintain prompt size budget (#47394)

Ben Kunkle created

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/zeta_prompt/src/zeta_prompt.rs | 80 +++++++++++++++++++++++++++-
1 file changed, 76 insertions(+), 4 deletions(-)

Detailed changes

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -7,6 +7,11 @@ use std::sync::Arc;
 use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
 
 pub const CURSOR_MARKER: &str = "<|user_cursor|>";
+pub const MAX_PROMPT_TOKENS: usize = 4096;
+
+fn estimate_tokens(bytes: usize) -> usize {
+    bytes / 3
+}
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ZetaPromptInput {
@@ -131,8 +136,8 @@ pub struct RelatedExcerpt {
 
 pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
     let mut prompt = String::new();
-    write_related_files(&mut prompt, &input.related_files);
-    write_edit_history_section(&mut prompt, input);
+    let mut related_file_ranges = write_related_files(&mut prompt, &input.related_files);
+    let mut event_ranges = write_edit_history_section(&mut prompt, input);
 
     match version {
         ZetaVersion::V0112MiddleAtEnd => {
@@ -147,11 +152,70 @@ pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> Stri
         }
     }
 
+    truncate_prompt_to_budget(
+        &mut prompt,
+        &mut related_file_ranges,
+        &mut event_ranges,
+        MAX_PROMPT_TOKENS,
+    );
+
     prompt
 }
 
-pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
+fn truncate_prompt_to_budget(
+    prompt: &mut String,
+    related_file_ranges: &mut Vec<Range<usize>>,
+    event_ranges: &mut Vec<Range<usize>>,
+    max_tokens: usize,
+) {
+    let mut remove_from_related_files = true;
+
+    while estimate_tokens(prompt.len()) > max_tokens {
+        let range_to_remove = if remove_from_related_files {
+            related_file_ranges.pop()
+        } else {
+            event_ranges.pop()
+        };
+
+        let Some(range) = range_to_remove else {
+            if remove_from_related_files && !event_ranges.is_empty() {
+                remove_from_related_files = false;
+                continue;
+            } else if !remove_from_related_files && !related_file_ranges.is_empty() {
+                remove_from_related_files = true;
+                continue;
+            } else {
+                break;
+            }
+        };
+
+        let removed_len = range.end - range.start;
+        prompt.replace_range(range.clone(), "");
+
+        for r in related_file_ranges.iter_mut() {
+            if r.start > range.start {
+                r.start -= removed_len;
+                r.end -= removed_len;
+            }
+        }
+        for r in event_ranges.iter_mut() {
+            if r.start > range.start {
+                r.start -= removed_len;
+                r.end -= removed_len;
+            }
+        }
+
+        remove_from_related_files = !remove_from_related_files;
+    }
+}
+
+pub fn write_related_files(
+    prompt: &mut String,
+    related_files: &[RelatedFile],
+) -> Vec<Range<usize>> {
+    let mut ranges = Vec::new();
     for file in related_files {
+        let start = prompt.len();
         let path_str = file.path.to_string_lossy();
         write!(prompt, "<|file_sep|>{}\n", path_str).ok();
         for excerpt in &file.excerpts {
@@ -163,14 +227,22 @@ pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
                 prompt.push_str("...\n");
             }
         }
+        let end = prompt.len();
+        ranges.push(start..end);
     }
+    ranges
 }
 
-fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
+fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) -> Vec<Range<usize>> {
+    let mut ranges = Vec::new();
     prompt.push_str("<|file_sep|>edit history\n");
     for event in &input.events {
+        let start = prompt.len();
         write_event(prompt, event);
+        let end = prompt.len();
+        ranges.push(start..end);
     }
+    ranges
 }
 
 mod v0112_middle_at_end {