ep: Add V0131GitMergeMarkersPrefix prompt format (#48145)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/zeta2.rs             |  4 
crates/edit_prediction_cli/src/format_prompt.rs | 11 ++-
crates/edit_prediction_cli/src/parse_output.rs  | 22 ++++--
crates/zeta_prompt/src/zeta_prompt.rs           | 65 +++++++++++++++++++
4 files changed, 90 insertions(+), 12 deletions(-)

Detailed changes

crates/edit_prediction/src/zeta2.rs 🔗

@@ -23,7 +23,9 @@ pub const MAX_CONTEXT_TOKENS: usize = 350;
 pub fn max_editable_tokens(version: ZetaVersion) -> usize {
     match version {
         ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
-        ZetaVersion::V0114180EditableRegion | ZetaVersion::V0120GitMergeMarkers => 180,
+        ZetaVersion::V0114180EditableRegion => 180,
+        ZetaVersion::V0120GitMergeMarkers => 180,
+        ZetaVersion::V0131GitMergeMarkersPrefix => 180,
     }
 }
 

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -154,11 +154,14 @@ pub fn zeta2_output_for_patch(
         result.insert_str(offset, zeta_prompt::CURSOR_MARKER);
     }
 
-    if version == ZetaVersion::V0120GitMergeMarkers {
-        if !result.ends_with('\n') {
-            result.push('\n');
+    match version {
+        ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => {
+            if !result.ends_with('\n') {
+                result.push('\n');
+            }
+            result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
         }
-        result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
+        _ => (),
     }
 
     Ok(result)

crates/edit_prediction_cli/src/parse_output.rs 🔗

@@ -56,7 +56,7 @@ fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<St
         ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
             ("<|fim_middle|>current\n", "<|fim_suffix|>")
         }
-        ZetaVersion::V0120GitMergeMarkers => (
+        ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => (
             zeta_prompt::v0120_git_merge_markers::START_MARKER,
             zeta_prompt::v0120_git_merge_markers::SEPARATOR,
         ),
@@ -76,7 +76,9 @@ fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<St
 
     let region = &prompt[start..end];
     let region = region.strip_suffix('\n').unwrap_or(region);
-    Ok(region.replace(CURSOR_MARKER, ""))
+    let region = region.replace(CURSOR_MARKER, "");
+
+    Ok(region)
 }
 
 fn parse_zeta2_output(
@@ -100,12 +102,18 @@ fn parse_zeta2_output(
         None
     };
 
-    if version == ZetaVersion::V0120GitMergeMarkers {
-        if let Some(stripped) =
-            new_text.strip_suffix(zeta_prompt::v0120_git_merge_markers::END_MARKER)
-        {
-            new_text = stripped.to_string();
+    let suffix = match version {
+        ZetaVersion::V0131GitMergeMarkersPrefix => {
+            zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
         }
+        ZetaVersion::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
+        _ => "",
+    };
+    if !suffix.is_empty() {
+        new_text = new_text
+            .strip_suffix(suffix)
+            .unwrap_or(&new_text)
+            .to_string();
     }
 
     let mut old_text_normalized = old_text.clone();

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -45,6 +45,7 @@ pub enum ZetaVersion {
     #[default]
     V0114180EditableRegion,
     V0120GitMergeMarkers,
+    V0131GitMergeMarkersPrefix,
 }
 
 impl std::fmt::Display for ZetaVersion {
@@ -156,6 +157,9 @@ fn format_zeta_prompt_with_budget(
         ZetaVersion::V0120GitMergeMarkers => {
             v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
         }
+        ZetaVersion::V0131GitMergeMarkersPrefix => {
+            v0131_git_merge_markers_prefix::write_cursor_excerpt_section(&mut cursor_section, input)
+        }
     }
 
     let cursor_tokens = estimate_tokens(cursor_section.len());
@@ -418,6 +422,67 @@ pub mod v0120_git_merge_markers {
     }
 }
 
+pub mod v0131_git_merge_markers_prefix {
+    //! A prompt that uses git-style merge conflict markers to represent the editable region.
+    //!
+    //! Example prompt:
+    //!
+    //! <|file_sep|>path/to/target_file.py
+    //! <|fim_prefix|>
+    //! code before editable region
+    //! <<<<<<< CURRENT
+    //! code that
+    //! needs to<|user_cursor|>
+    //! be rewritten
+    //! =======
+    //! <|fim_suffix|>
+    //! code after editable region
+    //! <|fim_middle|>
+    //!
+    //! Expected output (should be generated by the model):
+    //!
+    //! updated
+    //! code with
+    //! changes applied
+    //! >>>>>>> UPDATED
+
+    use super::*;
+
+    pub const START_MARKER: &str = "<<<<<<< CURRENT\n";
+    pub const SEPARATOR: &str = "=======\n";
+    pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
+
+    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
+        let path_str = input.cursor_path.to_string_lossy();
+        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
+
+        prompt.push_str("<|fim_prefix|>");
+        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+        prompt.push_str(START_MARKER);
+        prompt.push_str(
+            &input.cursor_excerpt
+                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
+        );
+        prompt.push_str(CURSOR_MARKER);
+        prompt.push_str(
+            &input.cursor_excerpt
+                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
+        );
+        if !prompt.ends_with('\n') {
+            prompt.push('\n');
+        }
+        prompt.push_str(SEPARATOR);
+
+        prompt.push_str("<|fim_suffix|>");
+        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+        if !prompt.ends_with('\n') {
+            prompt.push('\n');
+        }
+
+        prompt.push_str("<|fim_middle|>");
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;