ep: Remove code duplication and extend prompt size limit (#54453)

Oleksiy Syvokon created

V0327 requires a larger prompt limit.


Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/format_prompt.rs | 131 ------------------
crates/zeta_prompt/src/zeta_prompt.rs           |  35 ++++
2 files changed, 35 insertions(+), 131 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -7,13 +7,10 @@ use crate::{
 };
 use anyhow::{Context as _, Result, anyhow};
 use gpui::AsyncApp;
-use similar::DiffableStr;
 use std::ops::Range;
 use std::sync::Arc;
-use zeta_prompt::udiff;
 use zeta_prompt::{
-    ZetaFormat, encode_patch_as_output_for_format, format_zeta_prompt, multi_region,
-    output_end_marker_for_format, resolve_cursor_region,
+    ZetaFormat, format_expected_output, format_zeta_prompt, multi_region, resolve_cursor_region,
 };
 
 fn resolved_excerpt_ranges_for_format(
@@ -88,17 +85,17 @@ pub async fn run_format_prompt(
                 .into_iter()
                 .next()
                 .and_then(|(expected_patch, expected_cursor_offset)| {
-                    zeta2_output_for_patch(
+                    format_expected_output(
                         prompt_inputs,
+                        zeta_format,
                         &expected_patch,
                         expected_cursor_offset,
-                        zeta_format,
                     )
                     .ok()
                 });
 
             let rejected_output = example.spec.rejected_patch.as_ref().and_then(|patch| {
-                zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok()
+                format_expected_output(prompt_inputs, zeta_format, patch, None).ok()
             });
 
             example.prompt = prompt.map(|prompt| ExamplePrompt {
@@ -116,126 +113,6 @@ pub async fn run_format_prompt(
     Ok(())
 }
 
-pub fn zeta2_output_for_patch(
-    input: &zeta_prompt::ZetaPromptInput,
-    patch: &str,
-    cursor_offset: Option<usize>,
-    version: ZetaFormat,
-) -> Result<String> {
-    let (context, editable_range, _, _) = resolve_cursor_region(input, version);
-    let mut old_editable_region = context[editable_range].to_string();
-
-    if !old_editable_region.ends_with_newline() {
-        old_editable_region.push('\n');
-    }
-
-    if let Some(encoded_output) =
-        encode_patch_as_output_for_format(version, &old_editable_region, patch, cursor_offset)?
-    {
-        return Ok(encoded_output);
-    }
-
-    let (result, first_hunk_offset) =
-        udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable_region).with_context(
-            || {
-                format!(
-                    "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
-                    patch, old_editable_region
-                )
-            },
-        )?;
-
-    if version == ZetaFormat::V0317SeedMultiRegions {
-        let cursor_in_new = cursor_offset.map(|cursor_offset| {
-            let hunk_start = first_hunk_offset.unwrap_or(0);
-            result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
-        });
-        return multi_region::encode_from_old_and_new_v0317(
-            &old_editable_region,
-            &result,
-            cursor_in_new,
-            zeta_prompt::CURSOR_MARKER,
-            multi_region::V0317_END_MARKER,
-        );
-    }
-
-    if version == ZetaFormat::V0318SeedMultiRegions {
-        let cursor_in_new = cursor_offset.map(|cursor_offset| {
-            let hunk_start = first_hunk_offset.unwrap_or(0);
-            result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
-        });
-        return multi_region::encode_from_old_and_new_v0318(
-            &old_editable_region,
-            &result,
-            cursor_in_new,
-            zeta_prompt::CURSOR_MARKER,
-            multi_region::V0318_END_MARKER,
-        );
-    }
-
-    if version == ZetaFormat::V0327SingleFile {
-        let cursor_in_new = cursor_offset.map(|cursor_offset| {
-            let hunk_start = first_hunk_offset.unwrap_or(0);
-            result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
-        });
-        return multi_region::encode_from_old_and_new_v0318(
-            &old_editable_region,
-            &result,
-            cursor_in_new,
-            zeta_prompt::CURSOR_MARKER,
-            multi_region::V0327_END_MARKER,
-        );
-    }
-
-    if version == ZetaFormat::V0316SeedMultiRegions {
-        let cursor_in_new = cursor_offset.map(|cursor_offset| {
-            let hunk_start = first_hunk_offset.unwrap_or(0);
-            result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
-        });
-        return multi_region::encode_from_old_and_new_v0316(
-            &old_editable_region,
-            &result,
-            cursor_in_new,
-            zeta_prompt::CURSOR_MARKER,
-            multi_region::V0316_END_MARKER,
-        );
-    }
-
-    if version == ZetaFormat::V0306SeedMultiRegions {
-        let cursor_in_new = cursor_offset.map(|cursor_offset| {
-            let hunk_start = first_hunk_offset.unwrap_or(0);
-            result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()))
-        });
-        return multi_region::encode_from_old_and_new(
-            &old_editable_region,
-            &result,
-            cursor_in_new,
-            zeta_prompt::CURSOR_MARKER,
-            zeta_prompt::seed_coder::END_MARKER,
-            zeta_prompt::seed_coder::NO_EDITS,
-        );
-    }
-
-    let mut result = result;
-    if let Some(cursor_offset) = cursor_offset {
-        // The cursor_offset is relative to the start of the hunk's new text (context + additions).
-        // We need to add where the hunk context matched in the editable region to compute
-        // the actual cursor position in the result.
-        let hunk_start = first_hunk_offset.unwrap_or(0);
-        let offset = result.floor_char_boundary((hunk_start + cursor_offset).min(result.len()));
-        result.insert_str(offset, zeta_prompt::CURSOR_MARKER);
-    }
-
-    if let Some(end_marker) = output_end_marker_for_format(version) {
-        if !result.ends_with('\n') {
-            result.push('\n');
-        }
-        result.push_str(end_marker);
-    }
-
-    Ok(result)
-}
-
 pub struct TeacherPrompt;
 
 impl TeacherPrompt {

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -15,7 +15,6 @@ pub use crate::excerpt_ranges::{
 };
 
 pub const CURSOR_MARKER: &str = "<|user_cursor|>";
-pub const MAX_PROMPT_TOKENS: usize = 4096;
 
 /// Use up to this amount of the editable region for prefill.
 /// Larger values may result in more robust generation, but
@@ -230,7 +229,25 @@ pub fn prompt_input_contains_special_tokens(input: &ZetaPromptInput, format: Zet
 }
 
 pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> Option<String> {
-    format_prompt_with_budget_for_format(input, format, MAX_PROMPT_TOKENS)
+    let max_prompt_tokens = match format {
+        ZetaFormat::V0112MiddleAtEnd
+        | ZetaFormat::V0113Ordered
+        | ZetaFormat::V0114180EditableRegion
+        | ZetaFormat::V0120GitMergeMarkers
+        | ZetaFormat::V0131GitMergeMarkersPrefix
+        | ZetaFormat::V0211Prefill
+        | ZetaFormat::V0211SeedCoder
+        | ZetaFormat::v0226Hashline
+        | ZetaFormat::V0304VariableEdit
+        | ZetaFormat::V0304SeedNoEdits
+        | ZetaFormat::V0306SeedMultiRegions
+        | ZetaFormat::V0316SeedMultiRegions
+        | ZetaFormat::V0317SeedMultiRegions
+        | ZetaFormat::V0318SeedMultiRegions => 4096,
+        ZetaFormat::V0327SingleFile => 16384,
+    };
+
+    format_prompt_with_budget_for_format(input, format, max_prompt_tokens)
 }
 
 pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] {
@@ -973,7 +990,7 @@ pub fn format_expected_output(
                 multi_region::V0316_END_MARKER,
             )
         }
-        ZetaFormat::V0318SeedMultiRegions => {
+        ZetaFormat::V0318SeedMultiRegions | ZetaFormat::V0327SingleFile => {
             let (new_editable, first_hunk_offset) =
                 udiff::apply_diff_to_string_with_hunk_offset(patch, &old_editable)?;
             let cursor_in_new = cursor_in_new_text(cursor_offset, first_hunk_offset, &new_editable);
@@ -999,7 +1016,17 @@ pub fn format_expected_output(
         }
         // V0131-style formats and fallback: produce new editable text with
         // cursor marker inserted, followed by the end marker.
-        _ => {
+        ZetaFormat::V0112MiddleAtEnd
+        | ZetaFormat::V0113Ordered
+        | ZetaFormat::V0114180EditableRegion
+        | ZetaFormat::V0120GitMergeMarkers
+        | ZetaFormat::V0131GitMergeMarkersPrefix
+        | ZetaFormat::V0211Prefill
+        | ZetaFormat::V0211SeedCoder
+        | ZetaFormat::v0226Hashline
+        | ZetaFormat::V0304VariableEdit
+        | ZetaFormat::V0304SeedNoEdits
+        | ZetaFormat::V0306SeedMultiRegions => {
             let (mut result, first_hunk_offset) = if empty_patch {
                 (old_editable.clone(), None)
             } else {