From 376d410b83cd20acd156230fa55ca003e55da301 Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Tue, 10 Mar 2026 13:43:27 +0200 Subject: [PATCH] ep: Add multi-region format (#51185) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This format generates fewer token while maintaining the quality: ``` Model Generated tokens ↓ DeltaChrF ↑ 0306-seed-multi-regions 46,239 80.62 0304-seed-no-edits 110,871 80.61 0303-seed 271,457 79.62 ``` In addition to the student format, this change adds a new teacher prompt. It seems to be worse than the original, but I haven't optimized it at all. Keeping it for now as a base for potential improvements. Release Notes: - N/A --- .../edit_prediction_cli/src/format_prompt.rs | 270 ++++++++- crates/edit_prediction_cli/src/main.rs | 62 +- .../edit_prediction_cli/src/parse_output.rs | 6 +- crates/edit_prediction_cli/src/predict.rs | 96 ++- .../src/prompts/teacher_multi_region.md | 366 ++++++++++++ crates/zeta_prompt/src/multi_region.rs | 557 ++++++++++++++++++ crates/zeta_prompt/src/zeta_prompt.rs | 110 +++- 7 files changed, 1442 insertions(+), 25 deletions(-) create mode 100644 crates/edit_prediction_cli/src/prompts/teacher_multi_region.md create mode 100644 crates/zeta_prompt/src/multi_region.rs diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index 324c297ba4c75d10a24b53c7961bd35e1f42e2cd..af955a05dce01fd34c37eb55d15b76b4a4592745 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -13,7 +13,7 @@ use std::ops::Range; use std::sync::Arc; use zeta_prompt::{ ZetaFormat, encode_patch_as_output_for_format, excerpt_range_for_format, format_zeta_prompt, - output_end_marker_for_format, resolve_cursor_region, + multi_region, output_end_marker_for_format, resolve_cursor_region, }; pub async fn run_format_prompt( @@ -49,6 +49,24 @@ pub async fn run_format_prompt( provider: args.provider, }); } + PredictionProvider::TeacherMultiRegion(_) + | PredictionProvider::TeacherMultiRegionNonBatching(_) => { + step_progress.set_substatus("formatting teacher multi-region prompt"); + + let zeta_format = ZetaFormat::default(); + let (editable_range, context_range) = + excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges); + + let prompt = + TeacherMultiRegionPrompt::format_prompt(example, editable_range, context_range); + example.prompt = Some(ExamplePrompt { + input: prompt, + expected_output: String::new(), + rejected_output: None, + prefill: None, + provider: args.provider, + }); + } PredictionProvider::Zeta2(zeta_format) => { step_progress.set_substatus("formatting zeta2 prompt"); @@ -108,7 +126,7 @@ pub fn zeta2_output_for_patch( return Ok(encoded_output); } - let (mut result, first_hunk_offset) = + let (result, first_hunk_offset) = udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context( || { format!( @@ -118,6 +136,22 @@ pub fn zeta2_output_for_patch( }, )?; + if version == ZetaFormat::V0306SeedMultiRegions { + let cursor_in_new = cursor_offset.map(|cursor_offset| { + let hunk_start = first_hunk_offset.unwrap_or(0); + result.floor_char_boundary((hunk_start + cursor_offset).min(result.len())) + }); + return multi_region::encode_from_old_and_new( + &old_editable_region, + &result, + cursor_in_new, + zeta_prompt::CURSOR_MARKER, + zeta_prompt::seed_coder::END_MARKER, + zeta_prompt::seed_coder::NO_EDITS, + ); + } + + let mut result = result; if let Some(cursor_offset) = cursor_offset { // The cursor_offset is relative to the start of the hunk's new text (context + additions). // We need to add where the hunk context matched in the editable region to compute @@ -211,7 +245,6 @@ impl TeacherPrompt { .context("editable region not found in prompt content")?; let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); - // Use full context so cursor offset (relative to editable region start) aligns with diff content let editable_region_lines = old_editable_region.lines().count() as u32; let diff = language::unified_diff_with_context( &old_editable_region, @@ -263,6 +296,7 @@ impl TeacherPrompt { .prompt_inputs .as_ref() .and_then(|pi| pi.related_files.as_deref()); + let Some(related_files) = related_files else { return "(No context)".to_string(); }; @@ -317,6 +351,202 @@ impl TeacherPrompt { } } +pub struct TeacherMultiRegionPrompt; + +impl TeacherMultiRegionPrompt { + pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>"; + pub(crate) const NO_EDITS: &str = "NO_EDITS"; + + /// Truncate edit history to this number of last lines + const MAX_HISTORY_LINES: usize = 128; + + pub fn format_prompt( + example: &Example, + editable_range: Range, + context_range: Range, + ) -> String { + let edit_history = Self::format_edit_history(&example.spec.edit_history); + let context = Self::format_context(example); + let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range); + + let prompt_template = crate::prompt_assets::get_prompt("teacher_multi_region.md"); + let prompt = prompt_template + .replace("{{context}}", &context) + .replace("{{edit_history}}", &edit_history) + .replace("{{cursor_excerpt}}", &cursor_excerpt); + + prompt + } + + pub fn parse(example: &Example, response: &str) -> Result<(String, Option)> { + let no_edits = (String::new(), None); + if let Some(last_codeblock) = extract_last_codeblock(&response) { + if last_codeblock.trim() == Self::NO_EDITS { + return Ok(no_edits); + } + } + + if response.trim().ends_with(Self::NO_EDITS) { + return Ok(no_edits); + } + + let prompt_inputs = example + .prompt_inputs + .as_ref() + .context("example is missing prompt inputs")?; + + let zeta_format = ZetaFormat::default(); + let (editable_range, _) = + excerpt_range_for_format(zeta_format, &prompt_inputs.excerpt_ranges); + let excerpt = prompt_inputs.cursor_excerpt.as_ref(); + let old_editable_region = &excerpt[editable_range.clone()]; + let marker_offsets = multi_region::compute_marker_offsets(old_editable_region); + + let codeblock = + extract_last_codeblock(&response).context("no codeblock found in model response")?; + let (start_num, end_num, raw_new_span) = multi_region::extract_marker_span(&codeblock)?; + + let start_idx = start_num + .checked_sub(1) + .context("marker numbers are 1-indexed")?; + let end_idx = end_num + .checked_sub(1) + .context("marker numbers are 1-indexed")?; + let start_byte = *marker_offsets + .get(start_idx) + .context("start marker number out of range")?; + let end_byte = *marker_offsets + .get(end_idx) + .context("end marker number out of range")?; + + if start_byte > end_byte { + return Err(anyhow!("start marker must come before end marker")); + } + + let cursor_in_span = raw_new_span.find(Self::USER_CURSOR_MARKER); + let new_span = raw_new_span.replace(Self::USER_CURSOR_MARKER, ""); + + let old_span = &old_editable_region[start_byte..end_byte]; + let mut new_span = new_span; + if old_span.ends_with('\n') && !new_span.ends_with('\n') && !new_span.is_empty() { + new_span.push('\n'); + } + if !old_span.ends_with('\n') && new_span.ends_with('\n') { + new_span.pop(); + } + + let mut new_editable_region = String::new(); + new_editable_region.push_str(&old_editable_region[..start_byte]); + new_editable_region.push_str(&new_span); + new_editable_region.push_str(&old_editable_region[end_byte..]); + + let cursor_offset = cursor_in_span.map(|pos| start_byte + pos); + + if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') { + new_editable_region.insert(0, '\n'); + } + + let editable_region_offset = editable_range.start; + let editable_region_start_line = excerpt[..editable_region_offset].matches('\n').count(); + + let editable_region_lines = old_editable_region.lines().count() as u32; + let diff = language::unified_diff_with_context( + old_editable_region, + &new_editable_region, + editable_region_start_line as u32, + editable_region_start_line as u32, + editable_region_lines, + ); + + let diff = indoc::formatdoc! {" + --- a/{path} + +++ b/{path} + {diff}", + path = example.spec.cursor_path.to_string_lossy(), + diff = diff, + }; + + let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| { + ActualCursor::from_editable_region( + &example.spec.cursor_path, + editable_region_cursor_offset, + &new_editable_region, + excerpt, + editable_region_offset, + editable_region_start_line, + ) + }); + + Ok((diff, actual_cursor)) + } + + fn format_edit_history(edit_history: &str) -> String { + let lines: Vec<&str> = edit_history.lines().collect(); + + if lines.is_empty() { + return "(No edit history)".to_string(); + } + + if lines.len() > Self::MAX_HISTORY_LINES { + let truncated = lines[lines.len() - Self::MAX_HISTORY_LINES..].join("\n"); + format!("{truncated}\n[...truncated...]") + } else { + lines.join("\n") + } + } + + pub fn format_context(example: &Example) -> String { + let related_files = example + .prompt_inputs + .as_ref() + .and_then(|pi| pi.related_files.as_deref()); + let Some(related_files) = related_files else { + return "(No context)".to_string(); + }; + + if related_files.is_empty() { + return "(No context)".to_string(); + } + + let prefix = "`````"; + let suffix = "`````\n\n"; + let max_tokens = 1024; + zeta_prompt::format_related_files_within_budget(related_files, &prefix, &suffix, max_tokens) + } + + fn format_cursor_excerpt( + example: &Example, + editable_range: Range, + context_range: Range, + ) -> String { + let mut result = String::new(); + + let prompt_inputs = example.prompt_inputs.as_ref().unwrap(); + let excerpt = prompt_inputs.cursor_excerpt.as_ref(); + let cursor_offset = prompt_inputs.cursor_offset_in_excerpt; + + let editable_text = &excerpt[editable_range.clone()]; + let cursor_in_editable = cursor_offset - editable_range.start; + + let path_str = example.spec.cursor_path.to_string_lossy(); + result.push_str(&format!("`````{path_str}\n")); + + result.push_str(&excerpt[context_range.start..editable_range.start]); + + multi_region::write_editable_with_markers( + &mut result, + editable_text, + cursor_in_editable, + Self::USER_CURSOR_MARKER, + ); + + result.push_str(&excerpt[editable_range.end..context_range.end]); + result.push_str("\n`````"); + + result + } +} + /// Extract the cursor excerpt from an example. /// First tries to extract from an existing prompt, then falls back to constructing from prompt_inputs. pub fn extract_cursor_excerpt_from_example(example: &Example) -> Option { @@ -461,7 +691,7 @@ mod tests { } #[test] - fn test_extract_editable_region() { + fn test_extract_editable_region_old_format() { let text = indoc::indoc! {" some lines are @@ -483,6 +713,38 @@ mod tests { ); } + #[test] + fn test_extract_editable_region_marker_format() { + let text = indoc::indoc! {" + some context + <|marker_1|> + one + two three + <|marker_2|> + more context + "}; + let parsed = multi_region::extract_editable_region_from_markers(text).unwrap(); + assert_eq!(parsed, "one\ntwo three"); + } + + #[test] + fn test_extract_editable_region_multi_markers() { + let text = indoc::indoc! {" + prefix + <|marker_1|> + aaa + bbb + <|marker_2|> + ccc + ddd + <|marker_3|> + suffix + "}; + let parsed = multi_region::extract_editable_region_from_markers(text).unwrap(); + // Intermediate marker and its trailing \n are stripped + assert_eq!(parsed, "aaa\nbbb\nccc\nddd"); + } + #[test] fn test_extract_last_codeblock_nested_bibtex() { let text = indoc::indoc! {r#" diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index afe25c5badcfff03babd5e951ae66839ce0f790b..1dcd1d4aa3ad34df853e9d7b193c246f151a61b2 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -360,7 +360,9 @@ enum PredictionProvider { Zeta2(ZetaFormat), Baseten(ZetaFormat), Teacher(TeacherBackend), + TeacherMultiRegion(TeacherBackend), TeacherNonBatching(TeacherBackend), + TeacherMultiRegionNonBatching(TeacherBackend), Repair, } @@ -379,9 +381,15 @@ impl std::fmt::Display for PredictionProvider { PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"), PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"), PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"), + PredictionProvider::TeacherMultiRegion(backend) => { + write!(f, "teacher-multi-region:{backend}") + } PredictionProvider::TeacherNonBatching(backend) => { write!(f, "teacher-non-batching:{backend}") } + PredictionProvider::TeacherMultiRegionNonBatching(backend) => { + write!(f, "teacher-multi-region-non-batching:{backend}") + } PredictionProvider::Repair => write!(f, "repair"), } } @@ -409,13 +417,27 @@ impl std::str::FromStr for PredictionProvider { .unwrap_or(TeacherBackend::default()); Ok(PredictionProvider::Teacher(backend)) } - "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => { + "teacher-multi-region" | "teacher_multi_region" => { + let backend = arg + .map(|a| a.parse()) + .transpose()? + .unwrap_or(TeacherBackend::default()); + Ok(PredictionProvider::TeacherMultiRegion(backend)) + } + "teacher-non-batching" | "teacher_non_batching" => { let backend = arg .map(|a| a.parse()) .transpose()? .unwrap_or(TeacherBackend::default()); Ok(PredictionProvider::TeacherNonBatching(backend)) } + "teacher-multi-region-non-batching" | "teacher_multi_region_non_batching" => { + let backend = arg + .map(|a| a.parse()) + .transpose()? + .unwrap_or(TeacherBackend::default()); + Ok(PredictionProvider::TeacherMultiRegionNonBatching(backend)) + } "repair" => Ok(PredictionProvider::Repair), "baseten" => { let format = arg @@ -426,9 +448,9 @@ impl std::str::FromStr for PredictionProvider { } _ => { anyhow::bail!( - "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:, teacher, teacher:, teacher-non-batching, repair\n\ + "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:, teacher, teacher:, teacher-multi-region, teacher-multi-region:, teacher-non-batching, teacher-multi-region-non-batching, repair\n\ For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\ - For teacher, you can specify a backend like `teacher:sonnet46` or `teacher:gpt52`.\n\ + For teacher providers, you can specify a backend like `teacher:sonnet46`, `teacher-multi-region:sonnet46`, `teacher-multi-region-non-batching:sonnet46`, or `teacher:gpt52`.\n\ Available zeta versions:\n{}", ZetaFormat::options_as_string() ) @@ -491,6 +513,40 @@ enum BatchProvider { Openai, } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn prediction_provider_multi_region_non_batched_round_trips_to_primary_spelling() { + let provider: PredictionProvider = "teacher-multi-region-non-batching:sonnet46" + .parse() + .unwrap(); + assert_eq!( + provider, + PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Sonnet46) + ); + assert_eq!( + provider.to_string(), + "teacher-multi-region-non-batching:sonnet46" + ); + } + + #[test] + fn prediction_provider_multi_region_non_batched_alias_round_trips_to_primary_spelling() { + let provider: PredictionProvider = + "teacher_multi_region_non_batching:gpt52".parse().unwrap(); + assert_eq!( + provider, + PredictionProvider::TeacherMultiRegionNonBatching(TeacherBackend::Gpt52) + ); + assert_eq!( + provider.to_string(), + "teacher-multi-region-non-batching:gpt52" + ); + } +} + impl EpArgs { fn output_path(&self) -> Option { if self.in_place { diff --git a/crates/edit_prediction_cli/src/parse_output.rs b/crates/edit_prediction_cli/src/parse_output.rs index 94058efd92ca4a166ba4976819963ef5d3286f5d..2b41384e176ac7a6cc5c3dc7f93ddbba3cf027ae 100644 --- a/crates/edit_prediction_cli/src/parse_output.rs +++ b/crates/edit_prediction_cli/src/parse_output.rs @@ -1,7 +1,7 @@ use crate::{ PredictionProvider, example::{ActualCursor, Example}, - format_prompt::TeacherPrompt, + format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt}, repair, }; use anyhow::{Context as _, Result}; @@ -41,6 +41,10 @@ pub fn parse_prediction_output( PredictionProvider::Teacher(_) | PredictionProvider::TeacherNonBatching(_) => { TeacherPrompt::parse(example, actual_output) } + PredictionProvider::TeacherMultiRegion(_) + | PredictionProvider::TeacherMultiRegionNonBatching(_) => { + TeacherMultiRegionPrompt::parse(example, actual_output) + } PredictionProvider::Zeta2(version) => parse_zeta2_output(example, actual_output, version), PredictionProvider::Repair => repair::parse(example, actual_output), _ => anyhow::bail!( diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 94e28d00da2d61f63b59364304c3b9b4276e15f7..9f70861b5ef7298141441ec09606fa77e341cbfd 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -2,7 +2,7 @@ use crate::{ FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend, anthropic_client::AnthropicClient, example::{Example, ExamplePrediction, ExamplePrompt}, - format_prompt::{TeacherPrompt, run_format_prompt}, + format_prompt::{TeacherMultiRegionPrompt, TeacherPrompt, run_format_prompt}, headless::EpAppState, load_project::run_load_project, openai_client::OpenAiClient, @@ -57,8 +57,10 @@ pub async fn run_prediction( ); }; - if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) = - provider + if let PredictionProvider::Teacher(backend) + | PredictionProvider::TeacherMultiRegion(backend) + | PredictionProvider::TeacherNonBatching(backend) + | PredictionProvider::TeacherMultiRegionNonBatching(backend) = provider { run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?; run_format_prompt( @@ -71,7 +73,10 @@ pub async fn run_prediction( .await?; let step_progress = example_progress.start(Step::Predict); - let batched = matches!(provider, PredictionProvider::Teacher(..)); + let batched = matches!( + provider, + PredictionProvider::Teacher(..) | PredictionProvider::TeacherMultiRegion(..) + ); return predict_teacher( example, backend, @@ -135,7 +140,9 @@ pub async fn run_prediction( PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, PredictionProvider::Teacher(..) + | PredictionProvider::TeacherMultiRegion(..) | PredictionProvider::TeacherNonBatching(..) + | PredictionProvider::TeacherMultiRegionNonBatching(..) | PredictionProvider::Repair | PredictionProvider::Baseten(_) => { unreachable!() @@ -403,7 +410,29 @@ async fn predict_anthropic( .collect::>() .join("\n"); - let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?; + let parser_provider = if batched { + example + .prompt + .as_ref() + .map(|prompt| prompt.provider) + .unwrap_or(PredictionProvider::Teacher(backend)) + } else { + match example.prompt.as_ref().map(|prompt| prompt.provider) { + Some(PredictionProvider::TeacherMultiRegion(_)) + | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => { + PredictionProvider::TeacherMultiRegionNonBatching(backend) + } + _ => PredictionProvider::TeacherNonBatching(backend), + } + }; + + let (actual_patch, actual_cursor) = match parser_provider { + PredictionProvider::TeacherMultiRegion(_) + | PredictionProvider::TeacherMultiRegionNonBatching(_) => { + TeacherMultiRegionPrompt::parse(example, &actual_output)? + } + _ => TeacherPrompt::parse(example, &actual_output)?, + }; let prediction = ExamplePrediction { actual_patch: Some(actual_patch), @@ -411,9 +440,20 @@ async fn predict_anthropic( actual_cursor, error: None, provider: if batched { - PredictionProvider::Teacher(backend) + match example.prompt.as_ref().map(|prompt| prompt.provider) { + Some(PredictionProvider::TeacherMultiRegion(_)) => { + PredictionProvider::TeacherMultiRegion(backend) + } + _ => PredictionProvider::Teacher(backend), + } } else { - PredictionProvider::TeacherNonBatching(backend) + match example.prompt.as_ref().map(|prompt| prompt.provider) { + Some(PredictionProvider::TeacherMultiRegion(_)) + | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => { + PredictionProvider::TeacherMultiRegionNonBatching(backend) + } + _ => PredictionProvider::TeacherNonBatching(backend), + } }, }; @@ -487,7 +527,29 @@ async fn predict_openai( .collect::>() .join("\n"); - let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?; + let parser_provider = if batched { + example + .prompt + .as_ref() + .map(|prompt| prompt.provider) + .unwrap_or(PredictionProvider::Teacher(backend)) + } else { + match example.prompt.as_ref().map(|prompt| prompt.provider) { + Some(PredictionProvider::TeacherMultiRegion(_)) + | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => { + PredictionProvider::TeacherMultiRegionNonBatching(backend) + } + _ => PredictionProvider::TeacherNonBatching(backend), + } + }; + + let (actual_patch, actual_cursor) = match parser_provider { + PredictionProvider::TeacherMultiRegion(_) + | PredictionProvider::TeacherMultiRegionNonBatching(_) => { + TeacherMultiRegionPrompt::parse(example, &actual_output)? + } + _ => TeacherPrompt::parse(example, &actual_output)?, + }; let prediction = ExamplePrediction { actual_patch: Some(actual_patch), @@ -495,9 +557,20 @@ async fn predict_openai( actual_cursor, error: None, provider: if batched { - PredictionProvider::Teacher(backend) + match example.prompt.as_ref().map(|prompt| prompt.provider) { + Some(PredictionProvider::TeacherMultiRegion(_)) => { + PredictionProvider::TeacherMultiRegion(backend) + } + _ => PredictionProvider::Teacher(backend), + } } else { - PredictionProvider::TeacherNonBatching(backend) + match example.prompt.as_ref().map(|prompt| prompt.provider) { + Some(PredictionProvider::TeacherMultiRegion(_)) + | Some(PredictionProvider::TeacherMultiRegionNonBatching(_)) => { + PredictionProvider::TeacherMultiRegionNonBatching(backend) + } + _ => PredictionProvider::TeacherNonBatching(backend), + } }, }; @@ -591,7 +664,8 @@ pub async fn predict_baseten( pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> { match provider { - Some(PredictionProvider::Teacher(backend)) => match backend { + Some(PredictionProvider::Teacher(backend)) + | Some(PredictionProvider::TeacherMultiRegion(backend)) => match backend { TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => { let llm_client = ANTHROPIC_CLIENT.get_or_init(|| { AnthropicClient::batch(&crate::paths::LLM_CACHE_DB) diff --git a/crates/edit_prediction_cli/src/prompts/teacher_multi_region.md b/crates/edit_prediction_cli/src/prompts/teacher_multi_region.md new file mode 100644 index 0000000000000000000000000000000000000000..61c5c8f3837a321cb565d5c2b089eec94fcc3dc5 --- /dev/null +++ b/crates/edit_prediction_cli/src/prompts/teacher_multi_region.md @@ -0,0 +1,366 @@ +# Instructions + +You are an edit prediction assistant in a code editor. Your task is to predict the next edit to a given region of code surrounding the user's cursor. + +1. Analyze the edit history to understand what the programmer is trying to achieve +2. Identify any incomplete refactoring or changes that need to be finished +3. Make the remaining edits that a human programmer would logically make next (by rewriting a region of code near their cursor) + +## Focus on + +- Completing any partially-applied changes made +- Ensuring consistency with the programming style and patterns already established +- Making edits that maintain or improve code quality + +## Rules + +- **NEVER undo or revert the user's recent edits.** Examine the diff in the edit history carefully: + - If a line was removed (starts with `-`), do NOT restore that content—even if the code now appears incomplete or broken without it + - If a line was added (starts with `+`), do NOT delete or significantly modify it + - If code appears broken or incomplete after the user's edit, output `NO_EDITS` rather than "fixing" it by reverting + - Only add NEW content that extends the user's work forward; never restore what they removed + - **Key test**: if your prediction would make the code more similar to what it was BEFORE the user's edit, output `NO_EDITS` instead + - **Never assume a deletion was accidental.** Even if removing content breaks the code, breaks a pattern, or leaves text looking "incomplete", respect it. The user may be mid-rewrite. Do NOT "complete" partial text by restoring what was deleted. +- Auto-generated code can be modified: Hunks marked with `// User accepted prediction:` contain code from a previous prediction the user accepted. Unlike user-typed content, these hunks CAN be edited, corrected, or replaced if it improves the code. The "never undo/revert" rule protects the user's *current typing intent*—auto-generated code doesn't carry this protection +- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals. +- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code. +- Keep existing formatting unless it's absolutely necessary +- When edit history and surrounding code suggest different edits, prioritize the most recent edits in the history as they best reflect current intent. +- Treat partial text at or near the cursor as the beginning of something the user is actively typing. Complete the code the user appears to be creating based on context. +- When completing partial code, prefer predictions that save meaningful keystrokes, even if this requires making educated guesses about the user's intent. +- For code, it's better to make a substantive prediction that might be rejected than to make a minimal prediction that saves only a few keystrokes. +- When the user is editing prose or documentation (e.g. Markdown, comments, plain text), predict conservatively. Complete the current fragment or sentence, but do not generate additional lines of free-form content since prose is less constrained than code and more prone to incorrect continuations. + +# Input Format + +You will be provided with: +1. The user's *edit history*, in chronological order. Use this to infer the user's trajectory and predict the next most logical edit. + - Hunks preceded by `// User accepted prediction:` indicate code that was auto-generated by a previous prediction and accepted by the user. These are treated differently than user-typed edits (see Rules). +2. A set of *related excerpts* from the user's codebase. Some of these may be needed for correctly predicting the next edit. + - `…` may appear within a related file to indicate that some code has been skipped. +3. An excerpt from the user's *current file*. + - The excerpt contains numbered *marker* tags (`<|marker_1|>`, `<|marker_2|>`, etc.) placed at block boundaries throughout the code. These markers divide the excerpt into spans that you can target for editing. + - Code that appears before the first marker or after the last marker is read-only context and cannot be edited. + - The `<|user_cursor|>` tag marks the user's current cursor position, as it stands after the last edit in the history. + +# Output Format + +- Briefly explain the user's current intent based on the edit history and their current cursor location. +- Output a markdown codeblock containing your predicted edit as a **marker-bounded span**: + - The codeblock must **start** with a marker tag (e.g. `<|marker_2|>`) and **end** with a marker tag (e.g. `<|marker_4|>`). + - The content between these two markers is the full replacement for that span in the original file. + - Choose the **narrowest** pair of markers that fully contains your predicted edits, to minimize unnecessary output. + - Reproduce any unchanged lines within the chosen span faithfully — do not omit or alter them. + - Do not include any intermediate marker tags in your output — only the start and end markers. +- If no edit is needed (the code is already complete and correct, or there is no clear next edit to make), output a codeblock containing only `NO_EDITS`: + ````` + NO_EDITS + ````` +- If there is a specific place in the predicted output where the user is likely to edit next, indicate it using the `<|user_cursor|>` tag. + +## Example 1 + +There is code missing at the cursor location. The related excerpts includes the definition of a relevant type. You should fill in the missing code. + +### Related Excerpts + +````` +struct Product { + name: String, + price: u32, +} +````` + +### User Edit History + +````` +--- a/src/calculate.rs ++++ b/src/calculate.rs +@@ -100,6 +100,7 @@ + fn calculate_total(products: &[Product]) -> u32 { + let mut total = 0; + for product in products { ++ total += ; + } + total + } +````` + +### Current File + +`````src/calculate.rs +fn calculate_total(products: &[Product]) -> u32 { +<|marker_1|> + let mut total = 0; + for product in products { + total += <|user_cursor|>; + } + total +<|marker_2|> +} +````` + +### Output + +The user is computing a sum based on a list of products. The only numeric field on `Product` is `price`, so they must intend to sum the prices. + +````` +<|marker_1|> + let mut total = 0; + for product in products { + total += product.price; + } + total +<|marker_2|> +````` + +## Example 2 + +The user appears to be in the process of typing an eprintln call. Rather than fixing the spelling issue by deleting the newly-inserted content, you must continue the user's trajectory. It's not clear what data they intend to print. You should fill in as much code as is obviously intended, and position the cursor so that the user can fill in the rest. + +### User Edit History + +````` +--- a/src/modal.rs ++++ b/src/modal.rs +@@ -100,4 +100,4 @@ + fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) { + modal_state.close(); +- modal_state.dismiss(); ++ eprmodal_state.dismiss(); + } +````` + +### Current File + +`````src/modal.rs +<|marker_1|> +// handle the close button click +fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) { +<|marker_2|> + modal_state.close(); + epr<|user_cursor|>modal_state.dismiss(); +} +<|marker_3|> +````` + +### Output + +The user is clearly starting to type `eprintln!()`, however, what they intend to print is not obvious. I should fill in the print call and string literal, with the cursor positioned inside the string literal so the user can print whatever they want. + +````` +<|marker_2|> + modal_state.close(); + eprintln!("<|user_cursor|>"); + modal_state.dismiss(); +} +<|marker_3|> +````` + +## Example 3 + +Here, the user is adding a function. There's no way to tell for sure what the function's name will be. In this situation, you should make a reasonable guess at the function's name and signature, and place the user's cursor in the function body. This way, if you guess correctly, it will save the user a meaningful number of keystrokes, and the file will be left in a coherent state. + +### User Edit History + +````` +--- a/src/modal.rs ++++ b/src/modal.rs +@@ -100,4 +100,4 @@ + fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) { + modal_state.close(); + modal_state.dismiss(); + } ++ ++fn + + fn handle_keystroke(modal_state: &mut ModalState, evt: &Event) { +````` + +### Current File + +`````src/modal.rs +// handle the close button click +fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) { + modal_state.close(); +<|marker_1|> + modal_state.dismiss(); +} + +fn<|user_cursor|> + +<|marker_2|> +fn handle_keystroke(modal_state: &mut ModalState, evt: &Event) { + modal_state.begin_edit(); +<|marker_3|> +````` + +### Output + +The user is adding a new function. The existing functions I see are `handle_close_button_click` and `handle_keystroke`, which have similar signatures. One possible function they might be adding is `handle_submit`. + +````` +<|marker_1|> + modal_state.dismiss(); +} + +fn handle_submit(modal_state: &mut ModalState, evt: &Event) { + <|user_cursor|> +} + +<|marker_2|> +````` + +## Example 4 + +The code is already complete and there is no clear next edit to make. You should output NO_EDITS. + +### User Edit History + +````` +--- a/src/utils.rs ++++ b/src/utils.rs +@@ -10,7 +10,7 @@ + fn add(a: i32, b: i32) -> i32 { +- a - b ++ a + b + } +````` + +### Current File + +`````src/utils.rs +<|marker_1|> +fn add(a: i32, b: i32) -> i32 { + a + b<|user_cursor|> +} +<|marker_2|> +````` + +### Output + +The user just fixed a bug in the `add` function, changing subtraction to addition. The code is now correct and complete. There is no clear next edit to make. + +````` +NO_EDITS +````` + +## Example 5 + +The user just deleted code, leaving behind what looks incomplete. You must NOT "complete" it by restoring deleted content—that would undo their edit. Output NO_EDITS. **This is the correct response even though the code appears broken.** + +### User Edit History + +````` +--- a/config.nix ++++ b/config.nix +@@ -10,7 +10,7 @@ + # /etc/modular/crashdb needs to be mutable +- ln -s /tmp/crashdb $out/etc/modular/crashdb ++ ln -s /tmp/cr $out/etc/modular/crashdb + ''; +````` + +### Current File + +`````config.nix +<|marker_1|> + # /etc/modular/crashdb needs to be mutable + ln -s /tmp/cr<|user_cursor|> $out/etc/modular/crashdb + ''; +<|marker_2|> +````` + +### Output + +The user deleted `ashdb` from `/tmp/crashdb`, leaving `/tmp/cr`. Although this looks like incomplete text that I could "complete", doing so would restore deleted content. The user intentionally removed that text—I must not undo their deletion. + +````` +NO_EDITS +````` + +## Example 6 + +The user accepted a prediction for a function, then started renaming it. The original arguments were auto-generated (marked with `// User accepted prediction:`), so they CAN be updated to match the new function name. This is NOT reverting user input—it's improving auto-generated scaffolding. + +### User Edit History + +````` +--- a/math_utils.py ++++ b/math_utils.py +@@ -3,3 +3,5 @@ + def calculate_rectangle_area(width, height): + return width * height + + ++de + +// User accepted prediction: +--- a/math_utils.py ++++ b/math_utils.py +@@ -3,5 +3,7 @@ + def calculate_rectangle_area(width, height): + return width * height + +-de ++def calculate_rectangle_perimeter(width, height): ++ + +--- a/math_utils.py ++++ b/math_utils.py +@@ -5,5 +5,5 @@ + return width * height + +-def calculate_rectangle_perimeter(width, height): ++def calculate_sq_perimeter(width, height): + +````` + +### Current File + +`````math_utils.py +<|marker_1|> +def calculate_rectangle_area(width, height): + return width * height + +<|marker_2|> +def calculate_sq<|user_cursor|>_perimeter(width, height): + +<|marker_3|> +````` + +### Output + +The user accepted a prediction for `calculate_rectangle_perimeter(width, height)`, then started renaming `rectangle` to `square`. Since squares have equal sides, the arguments should change from `(width, height)` to `(side)`. The arguments were auto-generated (from an accepted prediction), so modifying them is appropriate. + +````` +<|marker_2|> +def calculate_square_perimeter(side): + <|user_cursor|> +<|marker_3|> +````` + + + +# Your task: + +# 1. User Edit History + +````` +{{edit_history}} +````` + +# 2. Related excerpts + +{{context}} + +# 3. Current File + +{{cursor_excerpt}} + + + + +----- + +Based on the edit history and context above, predict the user's next edit within the marker-bounded spans. diff --git a/crates/zeta_prompt/src/multi_region.rs b/crates/zeta_prompt/src/multi_region.rs new file mode 100644 index 0000000000000000000000000000000000000000..1bac794b1d71fdf5ca8e086b748b8aa426bad1bd --- /dev/null +++ b/crates/zeta_prompt/src/multi_region.rs @@ -0,0 +1,557 @@ +use anyhow::{Context as _, Result, anyhow}; + +pub const MARKER_TAG_PREFIX: &str = "<|marker_"; +pub const MARKER_TAG_SUFFIX: &str = "|>"; +const MIN_BLOCK_LINES: usize = 3; +const MAX_BLOCK_LINES: usize = 8; + +pub fn marker_tag(number: usize) -> String { + format!("{MARKER_TAG_PREFIX}{number}{MARKER_TAG_SUFFIX}") +} + +/// Compute byte offsets within `editable_text` where marker boundaries should +/// be placed. +/// +/// Returns a sorted `Vec` that always starts with `0` and ends with +/// `editable_text.len()`. Interior offsets are placed at line boundaries +/// (right after a `\n`), preferring blank-line boundaries when available and +/// respecting `MIN_BLOCK_LINES` / `MAX_BLOCK_LINES` constraints. +pub fn compute_marker_offsets(editable_text: &str) -> Vec { + if editable_text.is_empty() { + return vec![0, 0]; + } + + let mut offsets = vec![0usize]; + let mut lines_since_last_marker = 0usize; + let mut byte_offset = 0usize; + + for line in editable_text.split('\n') { + let line_end = byte_offset + line.len() + 1; + let is_past_end = line_end > editable_text.len(); + let actual_line_end = line_end.min(editable_text.len()); + lines_since_last_marker += 1; + + let is_blank = line.trim().is_empty(); + + if !is_past_end && lines_since_last_marker >= MIN_BLOCK_LINES { + if is_blank { + // Blank-line boundary found. We'll place the marker when we + // find the next non-blank line (handled below). + } else if lines_since_last_marker >= MAX_BLOCK_LINES { + offsets.push(actual_line_end); + lines_since_last_marker = 0; + } + } + + // Non-blank line immediately following blank line(s): split here so + // the new block starts with this line. + if !is_blank && byte_offset > 0 && lines_since_last_marker >= MIN_BLOCK_LINES { + let before = &editable_text[..byte_offset]; + let has_preceding_blank_line = before + .strip_suffix('\n') + .map(|stripped| { + let last_line = match stripped.rfind('\n') { + Some(pos) => &stripped[pos + 1..], + None => stripped, + }; + last_line.trim().is_empty() + }) + .unwrap_or(false); + + if has_preceding_blank_line { + offsets.push(byte_offset); + lines_since_last_marker = 1; + } + } + + byte_offset = actual_line_end; + + // Re-check after blank-line logic since lines_since_last_marker may + // have been reset. + if !is_past_end && lines_since_last_marker >= MAX_BLOCK_LINES { + if *offsets.last().unwrap_or(&0) != actual_line_end { + offsets.push(actual_line_end); + lines_since_last_marker = 0; + } + } + } + + let end = editable_text.len(); + if *offsets.last().unwrap_or(&0) != end { + offsets.push(end); + } + + offsets +} + +/// Write the editable region content with marker tags, inserting the cursor +/// marker at the given offset within the editable text. +pub fn write_editable_with_markers( + output: &mut String, + editable_text: &str, + cursor_offset_in_editable: usize, + cursor_marker: &str, +) { + let marker_offsets = compute_marker_offsets(editable_text); + let mut cursor_placed = false; + for (i, &offset) in marker_offsets.iter().enumerate() { + let marker_num = i + 1; + if !output.is_empty() && !output.ends_with('\n') { + output.push('\n'); + } + output.push_str(&marker_tag(marker_num)); + + if let Some(&next_offset) = marker_offsets.get(i + 1) { + output.push('\n'); + let block = &editable_text[offset..next_offset]; + if !cursor_placed + && cursor_offset_in_editable >= offset + && cursor_offset_in_editable <= next_offset + { + cursor_placed = true; + let cursor_in_block = cursor_offset_in_editable - offset; + output.push_str(&block[..cursor_in_block]); + output.push_str(cursor_marker); + output.push_str(&block[cursor_in_block..]); + } else { + output.push_str(block); + } + } + } +} + +/// Strip any `<|marker_N|>` tags from `text`. +/// +/// When a marker tag sits on its own line (followed by `\n`), the trailing +/// newline is also removed so the surrounding lines stay joined naturally. +fn strip_marker_tags(text: &str) -> String { + let mut result = String::with_capacity(text.len()); + let mut pos = 0; + let bytes = text.as_bytes(); + while let Some(rel) = text[pos..].find(MARKER_TAG_PREFIX) { + result.push_str(&text[pos..pos + rel]); + let num_start = pos + rel + MARKER_TAG_PREFIX.len(); + if let Some(suffix_rel) = text[num_start..].find(MARKER_TAG_SUFFIX) { + let mut tag_end = num_start + suffix_rel + MARKER_TAG_SUFFIX.len(); + if bytes.get(tag_end) == Some(&b'\n') { + tag_end += 1; + } + pos = tag_end; + } else { + result.push_str(MARKER_TAG_PREFIX); + pos = num_start; + } + } + result.push_str(&text[pos..]); + result +} + +/// Parse model output that uses the marker format. +/// +/// Returns `(start_marker_num, end_marker_num, content_between_markers)`. +/// The leading format-level newline after the start marker is stripped. +/// Trailing newlines are preserved so blank-line endings in the editable +/// region are not lost. +/// +/// Any extra intermediate marker tags that the model may have inserted +/// between the first and last markers are stripped from the returned content. +pub fn extract_marker_span(text: &str) -> Result<(usize, usize, String)> { + let first_tag_start = text + .find(MARKER_TAG_PREFIX) + .context("no start marker found in output")?; + let first_num_start = first_tag_start + MARKER_TAG_PREFIX.len(); + let first_num_end = text[first_num_start..] + .find(MARKER_TAG_SUFFIX) + .map(|i| i + first_num_start) + .context("malformed start marker tag")?; + let start_num: usize = text[first_num_start..first_num_end] + .parse() + .context("start marker number is not a valid integer")?; + let first_tag_end = first_num_end + MARKER_TAG_SUFFIX.len(); + + let last_tag_start = text + .rfind(MARKER_TAG_PREFIX) + .context("no end marker found in output")?; + let last_num_start = last_tag_start + MARKER_TAG_PREFIX.len(); + let last_num_end = text[last_num_start..] + .find(MARKER_TAG_SUFFIX) + .map(|i| i + last_num_start) + .context("malformed end marker tag")?; + let end_num: usize = text[last_num_start..last_num_end] + .parse() + .context("end marker number is not a valid integer")?; + + if start_num == end_num { + return Err(anyhow!( + "start and end markers are the same (marker {})", + start_num + )); + } + + let mut content_start = first_tag_end; + if text.as_bytes().get(content_start) == Some(&b'\n') { + content_start += 1; + } + let content_end = last_tag_start; + + let content = &text[content_start..content_end.max(content_start)]; + let content = strip_marker_tags(content); + Ok((start_num, end_num, content)) +} + +/// Given old editable text and model output with marker span, reconstruct the +/// full new editable region. +pub fn apply_marker_span(old_editable: &str, output: &str) -> Result { + let (start_num, end_num, raw_new_span) = extract_marker_span(output)?; + let marker_offsets = compute_marker_offsets(old_editable); + + let start_idx = start_num + .checked_sub(1) + .context("marker numbers are 1-indexed")?; + let end_idx = end_num + .checked_sub(1) + .context("marker numbers are 1-indexed")?; + let start_byte = *marker_offsets + .get(start_idx) + .context("start marker number out of range")?; + let end_byte = *marker_offsets + .get(end_idx) + .context("end marker number out of range")?; + + if start_byte > end_byte { + return Err(anyhow!("start marker must come before end marker")); + } + + let old_span = &old_editable[start_byte..end_byte]; + let mut new_span = raw_new_span; + if old_span.ends_with('\n') && !new_span.ends_with('\n') && !new_span.is_empty() { + new_span.push('\n'); + } + if !old_span.ends_with('\n') && new_span.ends_with('\n') { + new_span.pop(); + } + + let mut result = String::new(); + result.push_str(&old_editable[..start_byte]); + result.push_str(&new_span); + result.push_str(&old_editable[end_byte..]); + + Ok(result) +} + +/// Compare old and new editable text, find the minimal marker span that covers +/// all changes, and encode the result with marker tags. +pub fn encode_from_old_and_new( + old_editable: &str, + new_editable: &str, + cursor_offset_in_new: Option, + cursor_marker: &str, + end_marker: &str, + no_edits_marker: &str, +) -> Result { + if old_editable == new_editable { + return Ok(format!("{no_edits_marker}{end_marker}")); + } + + let marker_offsets = compute_marker_offsets(old_editable); + + let common_prefix = old_editable + .bytes() + .zip(new_editable.bytes()) + .take_while(|(a, b)| a == b) + .count(); + + let old_remaining = old_editable.len() - common_prefix; + let new_remaining = new_editable.len() - common_prefix; + let max_suffix = old_remaining.min(new_remaining); + let common_suffix = old_editable.as_bytes()[old_editable.len() - max_suffix..] + .iter() + .rev() + .zip( + new_editable.as_bytes()[new_editable.len() - max_suffix..] + .iter() + .rev(), + ) + .take_while(|(a, b)| a == b) + .count(); + + let change_end_in_old = old_editable.len() - common_suffix; + + let start_marker_idx = marker_offsets + .iter() + .rposition(|&offset| offset <= common_prefix) + .unwrap_or(0); + let end_marker_idx = marker_offsets + .iter() + .position(|&offset| offset >= change_end_in_old) + .unwrap_or(marker_offsets.len() - 1); + + let old_start = marker_offsets[start_marker_idx]; + let old_end = marker_offsets[end_marker_idx]; + + let new_start = old_start; + let new_end = new_editable + .len() + .saturating_sub(old_editable.len().saturating_sub(old_end)); + + let new_span = &new_editable[new_start..new_end]; + + let start_marker_num = start_marker_idx + 1; + let end_marker_num = end_marker_idx + 1; + + let mut result = String::new(); + result.push_str(&marker_tag(start_marker_num)); + result.push('\n'); + + if let Some(cursor_offset) = cursor_offset_in_new { + if cursor_offset >= new_start && cursor_offset <= new_end { + let cursor_in_span = cursor_offset - new_start; + let bounded = cursor_in_span.min(new_span.len()); + result.push_str(&new_span[..bounded]); + result.push_str(cursor_marker); + result.push_str(&new_span[bounded..]); + } else { + result.push_str(new_span); + } + } else { + result.push_str(new_span); + } + + if !result.ends_with('\n') { + result.push('\n'); + } + result.push_str(&marker_tag(end_marker_num)); + result.push('\n'); + result.push_str(end_marker); + + Ok(result) +} + +/// Extract the full editable region from text that uses marker tags. +/// +/// Returns the concatenation of all block contents between the first and last +/// markers, with intermediate marker tags stripped. +pub fn extract_editable_region_from_markers(text: &str) -> Option { + let first_marker_start = text.find(MARKER_TAG_PREFIX)?; + + let mut markers: Vec<(usize, usize)> = Vec::new(); + let mut search_start = first_marker_start; + while let Some(rel_pos) = text[search_start..].find(MARKER_TAG_PREFIX) { + let tag_start = search_start + rel_pos; + let num_start = tag_start + MARKER_TAG_PREFIX.len(); + let num_end = text[num_start..].find(MARKER_TAG_SUFFIX)?; + let tag_end = num_start + num_end + MARKER_TAG_SUFFIX.len(); + markers.push((tag_start, tag_end)); + search_start = tag_end; + } + + if markers.len() < 2 { + return None; + } + + let (_, first_tag_end) = markers[0]; + let (last_tag_start, _) = markers[markers.len() - 1]; + + let mut content_start = first_tag_end; + if text.as_bytes().get(content_start) == Some(&b'\n') { + content_start += 1; + } + let mut content_end = last_tag_start; + if content_end > content_start && text.as_bytes().get(content_end - 1) == Some(&b'\n') { + content_end -= 1; + } + + let raw = &text[content_start..content_end]; + let result = strip_marker_tags(raw); + let result = result.strip_suffix('\n').unwrap_or(&result).to_string(); + Some(result) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_compute_marker_offsets_small_block() { + let text = "aaa\nbbb\nccc\n"; + let offsets = compute_marker_offsets(text); + assert_eq!(offsets, vec![0, text.len()]); + } + + #[test] + fn test_compute_marker_offsets_blank_line_split() { + let text = "aaa\nbbb\nccc\n\nddd\neee\nfff\n"; + let offsets = compute_marker_offsets(text); + assert_eq!(offsets[0], 0); + assert!(offsets.contains(&13), "offsets: {:?}", offsets); + assert_eq!(*offsets.last().unwrap(), text.len()); + } + + #[test] + fn test_compute_marker_offsets_max_lines_split() { + let text = "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n"; + let offsets = compute_marker_offsets(text); + assert!(offsets.len() >= 3, "offsets: {:?}", offsets); + } + + #[test] + fn test_compute_marker_offsets_empty() { + let offsets = compute_marker_offsets(""); + assert_eq!(offsets, vec![0, 0]); + } + + #[test] + fn test_extract_marker_span() { + let text = "<|marker_2|>\n new content\n<|marker_3|>\n"; + let (start, end, content) = extract_marker_span(text).unwrap(); + assert_eq!(start, 2); + assert_eq!(end, 3); + assert_eq!(content, " new content\n"); + } + + #[test] + fn test_extract_marker_span_multi_line() { + let text = "<|marker_1|>\nline1\nline2\nline3\n<|marker_4|>"; + let (start, end, content) = extract_marker_span(text).unwrap(); + assert_eq!(start, 1); + assert_eq!(end, 4); + assert_eq!(content, "line1\nline2\nline3\n"); + } + + #[test] + fn test_apply_marker_span_basic() { + let old = "aaa\nbbb\nccc\n"; + let output = "<|marker_1|>\naaa\nBBB\nccc\n<|marker_2|>"; + let result = apply_marker_span(old, output).unwrap(); + assert_eq!(result, "aaa\nBBB\nccc\n"); + } + + #[test] + fn test_apply_marker_span_preserves_trailing_blank_line() { + let old = "/\nresult\n\n"; + let output = "<|marker_1|>\n//\nresult\n\n<|marker_2|>"; + let result = apply_marker_span(old, output).unwrap(); + assert_eq!(result, "//\nresult\n\n"); + } + + #[test] + fn test_encode_no_edits() { + let old = "aaa\nbbb\nccc\n"; + let result = encode_from_old_and_new( + old, + old, + None, + "<|user_cursor|>", + ">>>>>>> UPDATED\n", + "NO_EDITS\n", + ) + .unwrap(); + assert_eq!(result, "NO_EDITS\n>>>>>>> UPDATED\n"); + } + + #[test] + fn test_encode_with_change() { + let old = "aaa\nbbb\nccc\n"; + let new = "aaa\nBBB\nccc\n"; + let result = encode_from_old_and_new( + old, + new, + None, + "<|user_cursor|>", + ">>>>>>> UPDATED\n", + "NO_EDITS\n", + ) + .unwrap(); + assert!(result.contains("<|marker_1|>")); + assert!(result.contains("<|marker_2|>")); + assert!(result.contains("aaa\nBBB\nccc\n")); + assert!(result.ends_with(">>>>>>> UPDATED\n")); + } + + #[test] + fn test_roundtrip_encode_apply() { + let old = "line1\nline2\nline3\n\nline5\nline6\nline7\nline8\nline9\nline10\n"; + let new = "line1\nline2\nline3\n\nline5\nLINE6\nline7\nline8\nline9\nline10\n"; + let encoded = encode_from_old_and_new( + old, + new, + None, + "<|user_cursor|>", + ">>>>>>> UPDATED\n", + "NO_EDITS\n", + ) + .unwrap(); + let output = encoded + .strip_suffix(">>>>>>> UPDATED\n") + .expect("should have end marker"); + let reconstructed = apply_marker_span(old, output).unwrap(); + assert_eq!(reconstructed, new); + } + + #[test] + fn test_extract_editable_region_from_markers_multi() { + let text = "prefix\n<|marker_1|>\naaa\nbbb\n<|marker_2|>\nccc\nddd\n<|marker_3|>\nsuffix"; + let parsed = extract_editable_region_from_markers(text).unwrap(); + assert_eq!(parsed, "aaa\nbbb\nccc\nddd"); + } + + #[test] + fn test_extract_editable_region_two_markers() { + let text = "<|marker_1|>\none\ntwo three\n<|marker_2|>"; + let parsed = extract_editable_region_from_markers(text).unwrap(); + assert_eq!(parsed, "one\ntwo three"); + } + + #[test] + fn test_encode_with_cursor() { + let old = "aaa\nbbb\nccc\n"; + let new = "aaa\nBBB\nccc\n"; + let result = encode_from_old_and_new( + old, + new, + Some(5), + "<|user_cursor|>", + ">>>>>>> UPDATED\n", + "NO_EDITS\n", + ) + .unwrap(); + assert!(result.contains("<|user_cursor|>"), "result: {result}"); + assert!(result.contains("B<|user_cursor|>BB"), "result: {result}"); + } + + #[test] + fn test_extract_marker_span_strips_intermediate_markers() { + let text = "<|marker_2|>\nline1\n<|marker_3|>\nline2\n<|marker_4|>"; + let (start, end, content) = extract_marker_span(text).unwrap(); + assert_eq!(start, 2); + assert_eq!(end, 4); + assert_eq!(content, "line1\nline2\n"); + } + + #[test] + fn test_extract_marker_span_strips_multiple_intermediate_markers() { + let text = "<|marker_1|>\naaa\n<|marker_2|>\nbbb\n<|marker_3|>\nccc\n<|marker_4|>"; + let (start, end, content) = extract_marker_span(text).unwrap(); + assert_eq!(start, 1); + assert_eq!(end, 4); + assert_eq!(content, "aaa\nbbb\nccc\n"); + } + + #[test] + fn test_apply_marker_span_with_extra_intermediate_marker() { + let old = "aaa\nbbb\nccc\n"; + let output = "<|marker_1|>\naaa\n<|marker_1|>\nBBB\nccc\n<|marker_2|>"; + let result = apply_marker_span(old, output).unwrap(); + assert_eq!(result, "aaa\nBBB\nccc\n"); + } + + #[test] + fn test_strip_marker_tags_inline() { + assert_eq!(strip_marker_tags("no markers here"), "no markers here"); + assert_eq!(strip_marker_tags("before<|marker_5|>after"), "beforeafter"); + assert_eq!( + strip_marker_tags("line1\n<|marker_3|>\nline2"), + "line1\nline2" + ); + } +} diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 1dd675e8b39ccab8403682beb040a075381aaf1d..41d02478c33ce807bf1771cf25799c9a427e63ed 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -1,4 +1,5 @@ pub mod excerpt_ranges; +pub mod multi_region; use anyhow::{Result, anyhow}; use serde::{Deserialize, Serialize}; @@ -81,6 +82,7 @@ pub enum ZetaFormat { v0226Hashline, V0304VariableEdit, V0304SeedNoEdits, + V0306SeedMultiRegions, } impl std::fmt::Display for ZetaFormat { @@ -218,6 +220,20 @@ pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] ZetaFormat::v0226Hashline => hashline::special_tokens(), ZetaFormat::V0304VariableEdit => v0304_variable_edit::special_tokens(), ZetaFormat::V0304SeedNoEdits => seed_coder::special_tokens(), + ZetaFormat::V0306SeedMultiRegions => { + static TOKENS: &[&str] = &[ + seed_coder::FIM_SUFFIX, + seed_coder::FIM_PREFIX, + seed_coder::FIM_MIDDLE, + seed_coder::FILE_MARKER, + seed_coder::START_MARKER, + seed_coder::SEPARATOR, + seed_coder::END_MARKER, + CURSOR_MARKER, + multi_region::MARKER_TAG_PREFIX, + ]; + TOKENS + } } } @@ -231,6 +247,7 @@ pub fn token_limits_for_format(format: ZetaFormat) -> (usize, usize) { | ZetaFormat::V0211Prefill | ZetaFormat::V0211SeedCoder | ZetaFormat::v0226Hashline + | ZetaFormat::V0306SeedMultiRegions | ZetaFormat::V0304SeedNoEdits => (350, 150), ZetaFormat::V0304VariableEdit => (1024, 0), } @@ -247,6 +264,7 @@ pub fn stop_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] { | ZetaFormat::V0211Prefill | ZetaFormat::V0211SeedCoder | ZetaFormat::V0304VariableEdit + | ZetaFormat::V0306SeedMultiRegions | ZetaFormat::V0304SeedNoEdits => &[], } } @@ -269,7 +287,8 @@ pub fn excerpt_ranges_for_format( | ZetaFormat::V0211Prefill | ZetaFormat::V0211SeedCoder | ZetaFormat::v0226Hashline - | ZetaFormat::V0304SeedNoEdits => ( + | ZetaFormat::V0304SeedNoEdits + | ZetaFormat::V0306SeedMultiRegions => ( ranges.editable_350.clone(), ranges.editable_350_context_150.clone(), ), @@ -344,9 +363,46 @@ pub fn write_cursor_excerpt_section_for_format( ZetaFormat::V0304VariableEdit => { v0304_variable_edit::write_cursor_excerpt_section(prompt, path, context, cursor_offset) } + ZetaFormat::V0306SeedMultiRegions => { + prompt.push_str(&build_v0306_cursor_prefix( + path, + context, + editable_range, + cursor_offset, + )); + } } } +fn build_v0306_cursor_prefix( + path: &Path, + context: &str, + editable_range: &Range, + cursor_offset: usize, +) -> String { + let mut section = String::new(); + let path_str = path.to_string_lossy(); + write!(section, "{}{}\n", seed_coder::FILE_MARKER, path_str).ok(); + + section.push_str(&context[..editable_range.start]); + section.push_str(seed_coder::START_MARKER); + + let editable_text = &context[editable_range.clone()]; + let cursor_in_editable = cursor_offset - editable_range.start; + multi_region::write_editable_with_markers( + &mut section, + editable_text, + cursor_in_editable, + CURSOR_MARKER, + ); + + if !section.ends_with('\n') { + section.push('\n'); + } + section.push_str(seed_coder::SEPARATOR); + section +} + fn offset_range_to_row_range(text: &str, range: Range) -> Range { let start_row = text[0..range.start].matches('\n').count() as u32; let mut end_row = start_row + text[range.clone()].matches('\n').count() as u32; @@ -392,6 +448,18 @@ pub fn format_prompt_with_budget_for_format( max_tokens, ) } + ZetaFormat::V0306SeedMultiRegions => { + let cursor_prefix = + build_v0306_cursor_prefix(path, context, &editable_range, cursor_offset); + seed_coder::assemble_fim_prompt( + context, + &editable_range, + &cursor_prefix, + &input.events, + related_files, + max_tokens, + ) + } _ => { let mut cursor_section = String::new(); write_cursor_excerpt_section_for_format( @@ -463,7 +531,7 @@ pub fn get_prefill_for_format( | ZetaFormat::V0211SeedCoder | ZetaFormat::v0226Hashline | ZetaFormat::V0304VariableEdit => String::new(), - ZetaFormat::V0304SeedNoEdits => String::new(), + ZetaFormat::V0304SeedNoEdits | ZetaFormat::V0306SeedMultiRegions => String::new(), } } @@ -472,7 +540,9 @@ pub fn output_end_marker_for_format(format: ZetaFormat) -> Option<&'static str> ZetaFormat::V0120GitMergeMarkers => Some(v0120_git_merge_markers::END_MARKER), ZetaFormat::V0131GitMergeMarkersPrefix => Some(v0131_git_merge_markers_prefix::END_MARKER), ZetaFormat::V0211Prefill => Some(v0131_git_merge_markers_prefix::END_MARKER), - ZetaFormat::V0211SeedCoder | ZetaFormat::V0304SeedNoEdits => Some(seed_coder::END_MARKER), + ZetaFormat::V0211SeedCoder + | ZetaFormat::V0304SeedNoEdits + | ZetaFormat::V0306SeedMultiRegions => Some(seed_coder::END_MARKER), ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion @@ -497,7 +567,9 @@ pub fn encode_patch_as_output_for_format( cursor_offset, ) .map(Some), - ZetaFormat::V0304SeedNoEdits => Ok(seed_coder::no_edits(patch)), + ZetaFormat::V0304SeedNoEdits | ZetaFormat::V0306SeedMultiRegions => { + Ok(seed_coder::no_edits(patch)) + } _ => Ok(None), } } @@ -543,6 +615,14 @@ pub fn parse_zeta2_model_output( output.to_string() }, ), + ZetaFormat::V0306SeedMultiRegions => ( + editable_range_in_context, + if output.starts_with(seed_coder::NO_EDITS) { + old_editable_region.to_string() + } else { + multi_region::apply_marker_span(old_editable_region, output)? + }, + ), _ => (editable_range_in_context, output.to_string()), }; @@ -2587,9 +2667,27 @@ pub mod seed_coder { related_files: &[RelatedFile], max_tokens: usize, ) -> String { - let suffix_section = build_suffix_section(context, editable_range); let cursor_prefix_section = build_cursor_prefix_section(path, context, editable_range, cursor_offset); + assemble_fim_prompt( + context, + editable_range, + &cursor_prefix_section, + events, + related_files, + max_tokens, + ) + } + + pub fn assemble_fim_prompt( + context: &str, + editable_range: &Range, + cursor_prefix_section: &str, + events: &[Arc], + related_files: &[RelatedFile], + max_tokens: usize, + ) -> String { + let suffix_section = build_suffix_section(context, editable_range); let suffix_tokens = estimate_tokens(suffix_section.len()); let cursor_prefix_tokens = estimate_tokens(cursor_prefix_section.len()); @@ -2622,7 +2720,7 @@ pub mod seed_coder { if !edit_history_section.is_empty() { prompt.push('\n'); } - prompt.push_str(&cursor_prefix_section); + prompt.push_str(cursor_prefix_section); prompt.push_str(FIM_MIDDLE); prompt }