zeta2: Improve format prompt budgeting (#47808)

Ben Kunkle and Max created

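Previously the whole prompt was rendered first and then trimmed by alternately removing related-file and edit-history ranges until it fit the token budget. The prompt is now assembled section by section: the cursor excerpt is always included, the edit history is given the remaining budget (keeping the most recent events), and related files receive whatever budget is left.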

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>

Change summary

Cargo.lock                            |   1 
crates/zeta_prompt/Cargo.toml         |   3 
crates/zeta_prompt/src/zeta_prompt.rs | 425 ++++++++++++++++++++++++----
3 files changed, 367 insertions(+), 62 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -21365,6 +21365,7 @@ name = "zeta_prompt"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "indoc",
  "serde",
  "strum 0.27.2",
 ]

crates/zeta_prompt/Cargo.toml 🔗

@@ -15,3 +15,6 @@ path = "src/zeta_prompt.rs"
 anyhow.workspace = true
 serde.workspace = true
 strum.workspace = true
+
+[dev-dependencies]
+indoc.workspace = true

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -135,78 +135,129 @@ pub struct RelatedExcerpt {
 }
 
 pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
-    let mut prompt = String::new();
-    let mut related_file_ranges = write_related_files(&mut prompt, &input.related_files);
-    let mut event_ranges = write_edit_history_section(&mut prompt, input);
+    format_zeta_prompt_with_budget(input, version, MAX_PROMPT_TOKENS)
+}
 
+fn format_zeta_prompt_with_budget(
+    input: &ZetaPromptInput,
+    version: ZetaVersion,
+    max_tokens: usize,
+) -> String {
+    let mut cursor_section = String::new();
     match version {
         ZetaVersion::V0112MiddleAtEnd => {
-            v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
+            v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input);
         }
         ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
-            v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
+            v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input)
         }
-
         ZetaVersion::V0120GitMergeMarkers => {
-            v0120_git_merge_markers::write_cursor_excerpt_section(&mut prompt, input)
+            v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
         }
     }
 
-    truncate_prompt_to_budget(
-        &mut prompt,
-        &mut related_file_ranges,
-        &mut event_ranges,
-        MAX_PROMPT_TOKENS,
-    );
+    let cursor_tokens = estimate_tokens(cursor_section.len());
+    let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens);
+
+    let edit_history_section =
+        format_edit_history_within_budget(&input.events, budget_after_cursor);
+    let edit_history_tokens = estimate_tokens(edit_history_section.len());
+    let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
+
+    let related_files_section =
+        format_related_files_within_budget(&input.related_files, budget_after_edit_history);
 
+    let mut prompt = String::new();
+    prompt.push_str(&related_files_section);
+    prompt.push_str(&edit_history_section);
+    prompt.push_str(&cursor_section);
     prompt
 }
 
-fn truncate_prompt_to_budget(
-    prompt: &mut String,
-    related_file_ranges: &mut Vec<Range<usize>>,
-    event_ranges: &mut Vec<Range<usize>>,
-    max_tokens: usize,
-) {
-    let mut remove_from_related_files = true;
-
-    while estimate_tokens(prompt.len()) > max_tokens {
-        let range_to_remove = if remove_from_related_files {
-            related_file_ranges.pop()
-        } else {
-            event_ranges.pop()
-        };
+fn format_edit_history_within_budget(events: &[Arc<Event>], max_tokens: usize) -> String {
+    let header = "<|file_sep|>edit history\n";
+    let header_tokens = estimate_tokens(header.len());
+    if header_tokens >= max_tokens {
+        return String::new();
+    }
 
-        let Some(range) = range_to_remove else {
-            if remove_from_related_files && !event_ranges.is_empty() {
-                remove_from_related_files = false;
-                continue;
-            } else if !remove_from_related_files && !related_file_ranges.is_empty() {
-                remove_from_related_files = true;
-                continue;
-            } else {
-                break;
-            }
-        };
+    let mut event_strings: Vec<String> = Vec::new();
+    let mut total_tokens = header_tokens;
 
-        let removed_len = range.end - range.start;
-        prompt.replace_range(range.clone(), "");
+    for event in events.iter().rev() {
+        let mut event_str = String::new();
+        write_event(&mut event_str, event);
+        let event_tokens = estimate_tokens(event_str.len());
 
-        for r in related_file_ranges.iter_mut() {
-            if r.start > range.start {
-                r.start -= removed_len;
-                r.end -= removed_len;
-            }
+        if total_tokens + event_tokens > max_tokens {
+            break;
+        }
+        total_tokens += event_tokens;
+        event_strings.push(event_str);
+    }
+
+    if event_strings.is_empty() {
+        return String::new();
+    }
+
+    let mut result = String::from(header);
+    for event_str in event_strings.iter().rev() {
+        result.push_str(&event_str);
+    }
+    result
+}
+
+fn format_related_files_within_budget(related_files: &[RelatedFile], max_tokens: usize) -> String {
+    let mut result = String::new();
+    let mut total_tokens = 0;
+
+    for file in related_files {
+        let path_str = file.path.to_string_lossy();
+        let header_len = "<|file_sep|>".len() + path_str.len() + 1;
+        let header_tokens = estimate_tokens(header_len);
+
+        if total_tokens + header_tokens > max_tokens {
+            break;
         }
-        for r in event_ranges.iter_mut() {
-            if r.start > range.start {
-                r.start -= removed_len;
-                r.end -= removed_len;
+
+        let mut file_tokens = header_tokens;
+        let mut excerpts_to_include = 0;
+
+        for excerpt in &file.excerpts {
+            let needs_newline = !excerpt.text.ends_with('\n');
+            let needs_ellipsis = excerpt.row_range.end < file.max_row;
+            let excerpt_len = excerpt.text.len()
+                + if needs_newline { "\n".len() } else { "".len() }
+                + if needs_ellipsis {
+                    "...\n".len()
+                } else {
+                    "".len()
+                };
+
+            let excerpt_tokens = estimate_tokens(excerpt_len);
+            if total_tokens + file_tokens + excerpt_tokens > max_tokens {
+                break;
             }
+            file_tokens += excerpt_tokens;
+            excerpts_to_include += 1;
         }
 
-        remove_from_related_files = !remove_from_related_files;
+        if excerpts_to_include > 0 {
+            total_tokens += file_tokens;
+            write!(result, "<|file_sep|>{}\n", path_str).ok();
+            for excerpt in file.excerpts.iter().take(excerpts_to_include) {
+                result.push_str(&excerpt.text);
+                if !result.ends_with('\n') {
+                    result.push('\n');
+                }
+                if excerpt.row_range.end < file.max_row {
+                    result.push_str("...\n");
+                }
+            }
+        }
     }
+
+    result
 }
 
 pub fn write_related_files(
@@ -233,18 +284,6 @@ pub fn write_related_files(
     ranges
 }
 
-fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) -> Vec<Range<usize>> {
-    let mut ranges = Vec::new();
-    prompt.push_str("<|file_sep|>edit history\n");
-    for event in &input.events {
-        let start = prompt.len();
-        write_event(prompt, event);
-        let end = prompt.len();
-        ranges.push(start..end);
-    }
-    ranges
-}
-
 mod v0112_middle_at_end {
     use super::*;
 
@@ -376,3 +415,265 @@ pub mod v0120_git_merge_markers {
         prompt.push_str(SEPARATOR);
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use indoc::indoc;
+
+    fn make_input(
+        cursor_excerpt: &str,
+        editable_range: Range<usize>,
+        cursor_offset: usize,
+        events: Vec<Event>,
+        related_files: Vec<RelatedFile>,
+    ) -> ZetaPromptInput {
+        ZetaPromptInput {
+            cursor_path: Path::new("test.rs").into(),
+            cursor_excerpt: cursor_excerpt.into(),
+            editable_range_in_excerpt: editable_range,
+            cursor_offset_in_excerpt: cursor_offset,
+            events: events.into_iter().map(Arc::new).collect(),
+            related_files,
+        }
+    }
+
+    fn make_event(path: &str, diff: &str) -> Event {
+        Event::BufferChange {
+            path: Path::new(path).into(),
+            old_path: Path::new(path).into(),
+            diff: diff.to_string(),
+            predicted: false,
+            in_open_source_repo: false,
+        }
+    }
+
+    fn make_related_file(path: &str, content: &str) -> RelatedFile {
+        RelatedFile {
+            path: Path::new(path).into(),
+            max_row: content.lines().count() as u32,
+            excerpts: vec![RelatedExcerpt {
+                row_range: 0..content.lines().count() as u32,
+                text: content.into(),
+            }],
+        }
+    }
+
+    fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
+        format_zeta_prompt_with_budget(input, ZetaVersion::V0114180EditableRegion, max_tokens)
+    }
+
+    #[test]
+    fn test_no_truncation_when_within_budget() {
+        let input = make_input(
+            "prefix\neditable\nsuffix",
+            7..15,
+            10,
+            vec![make_event("a.rs", "-old\n+new\n")],
+            vec![make_related_file("related.rs", "fn helper() {}\n")],
+        );
+
+        assert_eq!(
+            format_with_budget(&input, 10000),
+            indoc! {r#"
+                <|file_sep|>related.rs
+                fn helper() {}
+                <|file_sep|>edit history
+                --- a/a.rs
+                +++ b/a.rs
+                -old
+                +new
+                <|file_sep|>test.rs
+                <|fim_prefix|>
+                prefix
+                <|fim_middle|>current
+                edi<|user_cursor|>table
+                <|fim_suffix|>
+
+                suffix
+                <|fim_middle|>updated
+            "#}
+        );
+    }
+
+    #[test]
+    fn test_truncation_drops_edit_history_when_budget_tight() {
+        let input = make_input(
+            "code",
+            0..4,
+            2,
+            vec![make_event("a.rs", "-x\n+y\n")],
+            vec![
+                make_related_file("r1.rs", "a\n"),
+                make_related_file("r2.rs", "b\n"),
+            ],
+        );
+
+        assert_eq!(
+            format_with_budget(&input, 10000),
+            indoc! {r#"
+                <|file_sep|>r1.rs
+                a
+                <|file_sep|>r2.rs
+                b
+                <|file_sep|>edit history
+                --- a/a.rs
+                +++ b/a.rs
+                -x
+                +y
+                <|file_sep|>test.rs
+                <|fim_prefix|>
+                <|fim_middle|>current
+                co<|user_cursor|>de
+                <|fim_suffix|>
+                <|fim_middle|>updated
+            "#}
+        );
+
+        assert_eq!(
+            format_with_budget(&input, 50),
+            indoc! {r#"
+                <|file_sep|>r1.rs
+                a
+                <|file_sep|>r2.rs
+                b
+                <|file_sep|>test.rs
+                <|fim_prefix|>
+                <|fim_middle|>current
+                co<|user_cursor|>de
+                <|fim_suffix|>
+                <|fim_middle|>updated
+            "#}
+        );
+    }
+
+    #[test]
+    fn test_truncation_includes_partial_excerpts() {
+        let input = make_input(
+            "x",
+            0..1,
+            0,
+            vec![],
+            vec![RelatedFile {
+                path: Path::new("big.rs").into(),
+                max_row: 30,
+                excerpts: vec![
+                    RelatedExcerpt {
+                        row_range: 0..10,
+                        text: "first excerpt\n".into(),
+                    },
+                    RelatedExcerpt {
+                        row_range: 10..20,
+                        text: "second excerpt\n".into(),
+                    },
+                    RelatedExcerpt {
+                        row_range: 20..30,
+                        text: "third excerpt\n".into(),
+                    },
+                ],
+            }],
+        );
+
+        assert_eq!(
+            format_with_budget(&input, 10000),
+            indoc! {r#"
+                <|file_sep|>big.rs
+                first excerpt
+                ...
+                second excerpt
+                ...
+                third excerpt
+                <|file_sep|>test.rs
+                <|fim_prefix|>
+                <|fim_middle|>current
+                <|user_cursor|>x
+                <|fim_suffix|>
+                <|fim_middle|>updated
+            "#}
+        );
+
+        assert_eq!(
+            format_with_budget(&input, 50),
+            indoc! {r#"
+                <|file_sep|>big.rs
+                first excerpt
+                ...
+                <|file_sep|>test.rs
+                <|fim_prefix|>
+                <|fim_middle|>current
+                <|user_cursor|>x
+                <|fim_suffix|>
+                <|fim_middle|>updated
+            "#}
+        );
+    }
+
+    #[test]
+    fn test_truncation_drops_older_events_first() {
+        let input = make_input(
+            "x",
+            0..1,
+            0,
+            vec![make_event("old.rs", "-1\n"), make_event("new.rs", "-2\n")],
+            vec![],
+        );
+
+        assert_eq!(
+            format_with_budget(&input, 10000),
+            indoc! {r#"
+                <|file_sep|>edit history
+                --- a/old.rs
+                +++ b/old.rs
+                -1
+                --- a/new.rs
+                +++ b/new.rs
+                -2
+                <|file_sep|>test.rs
+                <|fim_prefix|>
+                <|fim_middle|>current
+                <|user_cursor|>x
+                <|fim_suffix|>
+                <|fim_middle|>updated
+            "#}
+        );
+
+        assert_eq!(
+            format_with_budget(&input, 55),
+            indoc! {r#"
+                <|file_sep|>edit history
+                --- a/new.rs
+                +++ b/new.rs
+                -2
+                <|file_sep|>test.rs
+                <|fim_prefix|>
+                <|fim_middle|>current
+                <|user_cursor|>x
+                <|fim_suffix|>
+                <|fim_middle|>updated
+            "#}
+        );
+    }
+
+    #[test]
+    fn test_cursor_excerpt_always_included_with_minimal_budget() {
+        let input = make_input(
+            "fn main() {}",
+            0..12,
+            3,
+            vec![make_event("a.rs", "-old\n+new\n")],
+            vec![make_related_file("related.rs", "helper\n")],
+        );
+
+        assert_eq!(
+            format_with_budget(&input, 30),
+            indoc! {r#"
+                <|file_sep|>test.rs
+                <|fim_prefix|>
+                <|fim_middle|>current
+                fn <|user_cursor|>main() {}
+                <|fim_suffix|>
+                <|fim_middle|>updated
+            "#}
+        );
+    }
+}