Cargo.lock 🔗
@@ -21365,6 +21365,7 @@ name = "zeta_prompt"
version = "0.1.0"
dependencies = [
"anyhow",
+ "indoc",
"serde",
"strum 0.27.2",
]
Ben Kunkle and Max created
Closes #ISSUE
Release Notes:
- N/A *or* Added/Fixed/Improved ...
---------
Co-authored-by: Max <max@zed.dev>
Cargo.lock | 1
crates/zeta_prompt/Cargo.toml | 3
crates/zeta_prompt/src/zeta_prompt.rs | 425 ++++++++++++++++++++++++----
3 files changed, 367 insertions(+), 62 deletions(-)
@@ -21365,6 +21365,7 @@ name = "zeta_prompt"
version = "0.1.0"
dependencies = [
"anyhow",
+ "indoc",
"serde",
"strum 0.27.2",
]
@@ -15,3 +15,6 @@ path = "src/zeta_prompt.rs"
anyhow.workspace = true
serde.workspace = true
strum.workspace = true
+
+[dev-dependencies]
+indoc.workspace = true
@@ -135,78 +135,129 @@ pub struct RelatedExcerpt {
}
pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
- let mut prompt = String::new();
- let mut related_file_ranges = write_related_files(&mut prompt, &input.related_files);
- let mut event_ranges = write_edit_history_section(&mut prompt, input);
+ format_zeta_prompt_with_budget(input, version, MAX_PROMPT_TOKENS)
+}
+fn format_zeta_prompt_with_budget(
+ input: &ZetaPromptInput,
+ version: ZetaVersion,
+ max_tokens: usize,
+) -> String {
+ let mut cursor_section = String::new();
match version {
ZetaVersion::V0112MiddleAtEnd => {
- v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
+ v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input);
}
ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
- v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
+ v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input)
}
-
ZetaVersion::V0120GitMergeMarkers => {
- v0120_git_merge_markers::write_cursor_excerpt_section(&mut prompt, input)
+ v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
}
}
- truncate_prompt_to_budget(
- &mut prompt,
- &mut related_file_ranges,
- &mut event_ranges,
- MAX_PROMPT_TOKENS,
- );
+ let cursor_tokens = estimate_tokens(cursor_section.len());
+ let budget_after_cursor = max_tokens.saturating_sub(cursor_tokens);
+
+ let edit_history_section =
+ format_edit_history_within_budget(&input.events, budget_after_cursor);
+ let edit_history_tokens = estimate_tokens(edit_history_section.len());
+ let budget_after_edit_history = budget_after_cursor.saturating_sub(edit_history_tokens);
+
+ let related_files_section =
+ format_related_files_within_budget(&input.related_files, budget_after_edit_history);
+ let mut prompt = String::new();
+ prompt.push_str(&related_files_section);
+ prompt.push_str(&edit_history_section);
+ prompt.push_str(&cursor_section);
prompt
}
-fn truncate_prompt_to_budget(
- prompt: &mut String,
- related_file_ranges: &mut Vec<Range<usize>>,
- event_ranges: &mut Vec<Range<usize>>,
- max_tokens: usize,
-) {
- let mut remove_from_related_files = true;
-
- while estimate_tokens(prompt.len()) > max_tokens {
- let range_to_remove = if remove_from_related_files {
- related_file_ranges.pop()
- } else {
- event_ranges.pop()
- };
+fn format_edit_history_within_budget(events: &[Arc<Event>], max_tokens: usize) -> String {
+ let header = "<|file_sep|>edit history\n";
+ let header_tokens = estimate_tokens(header.len());
+ if header_tokens >= max_tokens {
+ return String::new();
+ }
- let Some(range) = range_to_remove else {
- if remove_from_related_files && !event_ranges.is_empty() {
- remove_from_related_files = false;
- continue;
- } else if !remove_from_related_files && !related_file_ranges.is_empty() {
- remove_from_related_files = true;
- continue;
- } else {
- break;
- }
- };
+ let mut event_strings: Vec<String> = Vec::new();
+ let mut total_tokens = header_tokens;
- let removed_len = range.end - range.start;
- prompt.replace_range(range.clone(), "");
+ for event in events.iter().rev() {
+ let mut event_str = String::new();
+ write_event(&mut event_str, event);
+ let event_tokens = estimate_tokens(event_str.len());
- for r in related_file_ranges.iter_mut() {
- if r.start > range.start {
- r.start -= removed_len;
- r.end -= removed_len;
- }
+ if total_tokens + event_tokens > max_tokens {
+ break;
+ }
+ total_tokens += event_tokens;
+ event_strings.push(event_str);
+ }
+
+ if event_strings.is_empty() {
+ return String::new();
+ }
+
+ let mut result = String::from(header);
+ for event_str in event_strings.iter().rev() {
+ result.push_str(&event_str);
+ }
+ result
+}
+
+fn format_related_files_within_budget(related_files: &[RelatedFile], max_tokens: usize) -> String {
+ let mut result = String::new();
+ let mut total_tokens = 0;
+
+ for file in related_files {
+ let path_str = file.path.to_string_lossy();
+ let header_len = "<|file_sep|>".len() + path_str.len() + 1;
+ let header_tokens = estimate_tokens(header_len);
+
+ if total_tokens + header_tokens > max_tokens {
+ break;
}
- for r in event_ranges.iter_mut() {
- if r.start > range.start {
- r.start -= removed_len;
- r.end -= removed_len;
+
+ let mut file_tokens = header_tokens;
+ let mut excerpts_to_include = 0;
+
+ for excerpt in &file.excerpts {
+ let needs_newline = !excerpt.text.ends_with('\n');
+ let needs_ellipsis = excerpt.row_range.end < file.max_row;
+ let excerpt_len = excerpt.text.len()
+ + if needs_newline { "\n".len() } else { "".len() }
+ + if needs_ellipsis {
+ "...\n".len()
+ } else {
+ "".len()
+ };
+
+ let excerpt_tokens = estimate_tokens(excerpt_len);
+ if total_tokens + file_tokens + excerpt_tokens > max_tokens {
+ break;
}
+ file_tokens += excerpt_tokens;
+ excerpts_to_include += 1;
}
- remove_from_related_files = !remove_from_related_files;
+ if excerpts_to_include > 0 {
+ total_tokens += file_tokens;
+ write!(result, "<|file_sep|>{}\n", path_str).ok();
+ for excerpt in file.excerpts.iter().take(excerpts_to_include) {
+ result.push_str(&excerpt.text);
+ if !result.ends_with('\n') {
+ result.push('\n');
+ }
+ if excerpt.row_range.end < file.max_row {
+ result.push_str("...\n");
+ }
+ }
+ }
}
+
+ result
}
pub fn write_related_files(
@@ -233,18 +284,6 @@ pub fn write_related_files(
ranges
}
-fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) -> Vec<Range<usize>> {
- let mut ranges = Vec::new();
- prompt.push_str("<|file_sep|>edit history\n");
- for event in &input.events {
- let start = prompt.len();
- write_event(prompt, event);
- let end = prompt.len();
- ranges.push(start..end);
- }
- ranges
-}
-
mod v0112_middle_at_end {
use super::*;
@@ -376,3 +415,265 @@ pub mod v0120_git_merge_markers {
prompt.push_str(SEPARATOR);
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use indoc::indoc;
+
+ fn make_input(
+ cursor_excerpt: &str,
+ editable_range: Range<usize>,
+ cursor_offset: usize,
+ events: Vec<Event>,
+ related_files: Vec<RelatedFile>,
+ ) -> ZetaPromptInput {
+ ZetaPromptInput {
+ cursor_path: Path::new("test.rs").into(),
+ cursor_excerpt: cursor_excerpt.into(),
+ editable_range_in_excerpt: editable_range,
+ cursor_offset_in_excerpt: cursor_offset,
+ events: events.into_iter().map(Arc::new).collect(),
+ related_files,
+ }
+ }
+
+ fn make_event(path: &str, diff: &str) -> Event {
+ Event::BufferChange {
+ path: Path::new(path).into(),
+ old_path: Path::new(path).into(),
+ diff: diff.to_string(),
+ predicted: false,
+ in_open_source_repo: false,
+ }
+ }
+
+ fn make_related_file(path: &str, content: &str) -> RelatedFile {
+ RelatedFile {
+ path: Path::new(path).into(),
+ max_row: content.lines().count() as u32,
+ excerpts: vec![RelatedExcerpt {
+ row_range: 0..content.lines().count() as u32,
+ text: content.into(),
+ }],
+ }
+ }
+
+ fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
+ format_zeta_prompt_with_budget(input, ZetaVersion::V0114180EditableRegion, max_tokens)
+ }
+
+ #[test]
+ fn test_no_truncation_when_within_budget() {
+ let input = make_input(
+ "prefix\neditable\nsuffix",
+ 7..15,
+ 10,
+ vec![make_event("a.rs", "-old\n+new\n")],
+ vec![make_related_file("related.rs", "fn helper() {}\n")],
+ );
+
+ assert_eq!(
+ format_with_budget(&input, 10000),
+ indoc! {r#"
+ <|file_sep|>related.rs
+ fn helper() {}
+ <|file_sep|>edit history
+ --- a/a.rs
+ +++ b/a.rs
+ -old
+ +new
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ prefix
+ <|fim_middle|>current
+ edi<|user_cursor|>table
+ <|fim_suffix|>
+
+ suffix
+ <|fim_middle|>updated
+ "#}
+ );
+ }
+
+ #[test]
+ fn test_truncation_drops_edit_history_when_budget_tight() {
+ let input = make_input(
+ "code",
+ 0..4,
+ 2,
+ vec![make_event("a.rs", "-x\n+y\n")],
+ vec![
+ make_related_file("r1.rs", "a\n"),
+ make_related_file("r2.rs", "b\n"),
+ ],
+ );
+
+ assert_eq!(
+ format_with_budget(&input, 10000),
+ indoc! {r#"
+ <|file_sep|>r1.rs
+ a
+ <|file_sep|>r2.rs
+ b
+ <|file_sep|>edit history
+ --- a/a.rs
+ +++ b/a.rs
+ -x
+ +y
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ co<|user_cursor|>de
+ <|fim_suffix|>
+ <|fim_middle|>updated
+ "#}
+ );
+
+ assert_eq!(
+ format_with_budget(&input, 50),
+ indoc! {r#"
+ <|file_sep|>r1.rs
+ a
+ <|file_sep|>r2.rs
+ b
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ co<|user_cursor|>de
+ <|fim_suffix|>
+ <|fim_middle|>updated
+ "#}
+ );
+ }
+
+ #[test]
+ fn test_truncation_includes_partial_excerpts() {
+ let input = make_input(
+ "x",
+ 0..1,
+ 0,
+ vec![],
+ vec![RelatedFile {
+ path: Path::new("big.rs").into(),
+ max_row: 30,
+ excerpts: vec![
+ RelatedExcerpt {
+ row_range: 0..10,
+ text: "first excerpt\n".into(),
+ },
+ RelatedExcerpt {
+ row_range: 10..20,
+ text: "second excerpt\n".into(),
+ },
+ RelatedExcerpt {
+ row_range: 20..30,
+ text: "third excerpt\n".into(),
+ },
+ ],
+ }],
+ );
+
+ assert_eq!(
+ format_with_budget(&input, 10000),
+ indoc! {r#"
+ <|file_sep|>big.rs
+ first excerpt
+ ...
+ second excerpt
+ ...
+ third excerpt
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ <|user_cursor|>x
+ <|fim_suffix|>
+ <|fim_middle|>updated
+ "#}
+ );
+
+ assert_eq!(
+ format_with_budget(&input, 50),
+ indoc! {r#"
+ <|file_sep|>big.rs
+ first excerpt
+ ...
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ <|user_cursor|>x
+ <|fim_suffix|>
+ <|fim_middle|>updated
+ "#}
+ );
+ }
+
+ #[test]
+ fn test_truncation_drops_older_events_first() {
+ let input = make_input(
+ "x",
+ 0..1,
+ 0,
+ vec![make_event("old.rs", "-1\n"), make_event("new.rs", "-2\n")],
+ vec![],
+ );
+
+ assert_eq!(
+ format_with_budget(&input, 10000),
+ indoc! {r#"
+ <|file_sep|>edit history
+ --- a/old.rs
+ +++ b/old.rs
+ -1
+ --- a/new.rs
+ +++ b/new.rs
+ -2
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ <|user_cursor|>x
+ <|fim_suffix|>
+ <|fim_middle|>updated
+ "#}
+ );
+
+ assert_eq!(
+ format_with_budget(&input, 55),
+ indoc! {r#"
+ <|file_sep|>edit history
+ --- a/new.rs
+ +++ b/new.rs
+ -2
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ <|user_cursor|>x
+ <|fim_suffix|>
+ <|fim_middle|>updated
+ "#}
+ );
+ }
+
+ #[test]
+ fn test_cursor_excerpt_always_included_with_minimal_budget() {
+ let input = make_input(
+ "fn main() {}",
+ 0..12,
+ 3,
+ vec![make_event("a.rs", "-old\n+new\n")],
+ vec![make_related_file("related.rs", "helper\n")],
+ );
+
+ assert_eq!(
+ format_with_budget(&input, 30),
+ indoc! {r#"
+ <|file_sep|>test.rs
+ <|fim_prefix|>
+ <|fim_middle|>current
+ fn <|user_cursor|>main() {}
+ <|fim_suffix|>
+ <|fim_middle|>updated
+ "#}
+ );
+ }
+}