ep: Add a prompt with git-style merge markers (#47215)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/zeta2.rs             | 12 ++
crates/edit_prediction_cli/src/format_prompt.rs | 30 ++++++--
crates/zeta_prompt/src/zeta_prompt.rs           | 66 +++++++++++++++++++
3 files changed, 99 insertions(+), 9 deletions(-)

Detailed changes

crates/edit_prediction/src/zeta2.rs 🔗

@@ -14,14 +14,14 @@ use release_channel::AppVersion;
 use std::env;
 use std::{path::Path, sync::Arc, time::Instant};
 use zeta_prompt::format_zeta_prompt;
-use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
+use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers};
 
 pub const MAX_CONTEXT_TOKENS: usize = 350;
 
 pub fn max_editable_tokens(version: ZetaVersion) -> usize {
     match version {
         ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
-        ZetaVersion::V0114180EditableRegion => 180,
+        ZetaVersion::V0114180EditableRegion | ZetaVersion::V0120GitMergeMarkers => 180,
     }
 }
 
@@ -147,6 +147,14 @@ pub fn request_prediction_with_zeta2(
                 output_text = output_text.replace(CURSOR_MARKER, "");
             }
 
+            if zeta_version == ZetaVersion::V0120GitMergeMarkers {
+                if let Some(stripped) =
+                    output_text.strip_suffix(v0120_git_merge_markers::END_MARKER)
+                {
+                    output_text = stripped.to_string();
+                }
+            }
+
             let mut old_text = snapshot
                 .text_for_range(editable_offset_range.clone())
                 .collect::<String>();

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -12,6 +12,7 @@ use language::{Buffer, OffsetRangeExt, Point};
 use similar::DiffableStr;
 use std::sync::Arc;
 use std::{fmt::Write as _, ops::Range};
+use zeta_prompt::ZetaVersion;
 use zeta_prompt::format_zeta_prompt;
 
 pub async fn run_format_prompt(
@@ -104,6 +105,7 @@ pub async fn run_format_prompt(
                     .first()
                     .context("expected patches is empty")?
                     .clone(),
+                version,
             )?;
             example.prompt = Some(ExamplePrompt {
                 input: prompt,
@@ -118,7 +120,11 @@ pub async fn run_format_prompt(
     Ok(())
 }
 
-pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
+pub fn zeta2_output_for_patch(
+    input: &zeta_prompt::ZetaPromptInput,
+    patch: &str,
+    version: ZetaVersion,
+) -> Result<String> {
     let mut old_editable_region =
         input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
 
@@ -126,12 +132,22 @@ pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str)
         old_editable_region.push('\n');
     }
 
-    edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
-        format!(
-            "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
-            patch, old_editable_region
-        )
-    })
+    let mut result = edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region)
+        .with_context(|| {
+            format!(
+                "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
+                patch, old_editable_region
+            )
+        })?;
+
+    if version == ZetaVersion::V0120GitMergeMarkers {
+        if !result.ends_with('\n') {
+            result.push('\n');
+        }
+        result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
+    }
+
+    Ok(result)
 }
 
 pub struct TeacherPrompt;

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -37,6 +37,7 @@ pub enum ZetaVersion {
     V0113Ordered,
     #[default]
     V0114180EditableRegion,
+    V0120GitMergeMarkers,
 }
 
 impl std::fmt::Display for ZetaVersion {
@@ -140,6 +141,10 @@ pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> Stri
         ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
             v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
         }
+
+        ZetaVersion::V0120GitMergeMarkers => {
+            v0120_git_merge_markers::write_cursor_excerpt_section(&mut prompt, input)
+        }
     }
 
     prompt
@@ -238,3 +243,64 @@ mod v0113_ordered {
         prompt.push_str("<|fim_middle|>updated\n");
     }
 }
+
+pub mod v0120_git_merge_markers {
+    //! A prompt that uses git-style merge conflict markers to represent the editable region.
+    //!
+    //! Example prompt:
+    //!
+    //! <|file_sep|>path/to/target_file.py
+    //! <|fim_prefix|>
+    //! code before editable region
+    //! <|fim_suffix|>
+    //! code after editable region
+    //! <|fim_middle|>
+    //! <<<<<<< CURRENT
+    //! code that
+    //! needs to<|user_cursor|>
+    //! be rewritten
+    //! =======
+    //!
+    //! Expected output (should be generated by the model):
+    //!
+    //! updated
+    //! code with
+    //! changes applied
+    //! >>>>>>> UPDATED
+
+    use super::*;
+
+    pub const START_MARKER: &str = "<<<<<<< CURRENT\n";
+    pub const SEPARATOR: &str = "=======\n";
+    pub const END_MARKER: &str = ">>>>>>> UPDATED\n";
+
+    pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
+        let path_str = input.cursor_path.to_string_lossy();
+        write!(prompt, "<|file_sep|>{}\n", path_str).ok();
+
+        prompt.push_str("<|fim_prefix|>");
+        prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+
+        prompt.push_str("<|fim_suffix|>");
+        prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+        if !prompt.ends_with('\n') {
+            prompt.push('\n');
+        }
+
+        prompt.push_str("<|fim_middle|>");
+        prompt.push_str(START_MARKER);
+        prompt.push_str(
+            &input.cursor_excerpt
+                [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
+        );
+        prompt.push_str(CURSOR_MARKER);
+        prompt.push_str(
+            &input.cursor_excerpt
+                [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
+        );
+        if !prompt.ends_with('\n') {
+            prompt.push('\n');
+        }
+        prompt.push_str(SEPARATOR);
+    }
+}