diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs
index 12abeba846e2f413259b3baada7ea82b171b8a4f..59fa648c9f952302cc95bff607ee9035e026e7ea 100644
--- a/crates/zeta_prompt/src/zeta_prompt.rs
+++ b/crates/zeta_prompt/src/zeta_prompt.rs
@@ -7,6 +7,13 @@ use std::sync::Arc;
 use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
 
 pub const CURSOR_MARKER: &str = "<|user_cursor|>";
+/// Budget, in estimated tokens, that `format_zeta_prompt` passes to `truncate_prompt_to_budget`.
+pub const MAX_PROMPT_TOKENS: usize = 4096;
+
+/// Estimates a token count as one token per three bytes (integer division).
+fn estimate_tokens(bytes: usize) -> usize {
+    bytes / 3
+}
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ZetaPromptInput {
@@ -131,8 +138,8 @@ pub struct RelatedExcerpt {
 
 pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
     let mut prompt = String::new();
-    write_related_files(&mut prompt, &input.related_files);
-    write_edit_history_section(&mut prompt, input);
+    let mut related_file_ranges = write_related_files(&mut prompt, &input.related_files);
+    let mut event_ranges = write_edit_history_section(&mut prompt, input);
 
     match version {
         ZetaVersion::V0112MiddleAtEnd => {
@@ -147,11 +154,88 @@ pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> Stri
         }
     }
 
+    // Remove related-file and event sections as needed to fit the token budget.
+    truncate_prompt_to_budget(
+        &mut prompt,
+        &mut related_file_ranges,
+        &mut event_ranges,
+        MAX_PROMPT_TOKENS,
+    );
+
     prompt
 }
 
-pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
+/// Shrinks `prompt` until `estimate_tokens(prompt.len())` is at most `max_tokens`,
+/// or until no removable range is left.
+///
+/// `related_file_ranges` and `event_ranges` are byte ranges into `prompt`, as returned
+/// by `write_related_files` and `write_edit_history_section`. Ranges are removed one at
+/// a time, last-written first, alternating between the two lists and starting with
+/// related files; when one list is exhausted, removal continues from the other. Text
+/// outside the tracked ranges is never removed, so the prompt may still exceed the
+/// budget on return.
+///
+/// NOTE(review): `pop()` discards the entry written last. If `input.events` is ordered
+/// oldest-first, the most recent edit is dropped first; confirm that is intended.
+fn truncate_prompt_to_budget(
+    prompt: &mut String,
+    related_file_ranges: &mut Vec<Range<usize>>,
+    event_ranges: &mut Vec<Range<usize>>,
+    max_tokens: usize,
+) {
+    let mut remove_from_related_files = true;
+
+    while estimate_tokens(prompt.len()) > max_tokens {
+        let range_to_remove = if remove_from_related_files {
+            related_file_ranges.pop()
+        } else {
+            event_ranges.pop()
+        };
+
+        // The preferred list is empty: fall back to the other one, or stop if both are.
+        let Some(range) = range_to_remove else {
+            if remove_from_related_files && !event_ranges.is_empty() {
+                remove_from_related_files = false;
+                continue;
+            } else if !remove_from_related_files && !related_file_ranges.is_empty() {
+                remove_from_related_files = true;
+                continue;
+            } else {
+                break;
+            }
+        };
+
+        let removed_len = range.end - range.start;
+        prompt.replace_range(range.clone(), "");
+
+        // Ranges that start after the removed text have moved left by its length.
+        for r in related_file_ranges.iter_mut() {
+            if r.start > range.start {
+                r.start -= removed_len;
+                r.end -= removed_len;
+            }
+        }
+        for r in event_ranges.iter_mut() {
+            if r.start > range.start {
+                r.start -= removed_len;
+                r.end -= removed_len;
+            }
+        }
+
+        remove_from_related_files = !remove_from_related_files;
+    }
+}
+
+/// Appends one section per related file to `prompt`.
+///
+/// Returns the byte range of each file's section in `prompt`, in the order written.
+pub fn write_related_files(
+    prompt: &mut String,
+    related_files: &[RelatedFile],
+) -> Vec<Range<usize>> {
+    let mut ranges = Vec::new();
     for file in related_files {
+        let start = prompt.len();
         let path_str = file.path.to_string_lossy();
         write!(prompt, "<|file_sep|>{}\n", path_str).ok();
         for excerpt in &file.excerpts {
@@ -163,14 +247,25 @@ pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
                 prompt.push_str("...\n");
             }
         }
+        let end = prompt.len();
+        ranges.push(start..end);
     }
+    ranges
 }
 
-fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
+/// Appends the edit-history header followed by every event in `input.events`.
+///
+/// Returns the byte range of each event in `prompt`; the header is not part of any range.
+fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) -> Vec<Range<usize>> {
+    let mut ranges = Vec::new();
     prompt.push_str("<|file_sep|>edit history\n");
     for event in &input.events {
+        let start = prompt.len();
         write_event(prompt, event);
+        let end = prompt.len();
+        ranges.push(start..end);
     }
+    ranges
 }
 
 mod v0112_middle_at_end {