ep: Add `<|no_edits|>` command to hashlines format (#51103)

Ben Kunkle created

Closes #ISSUE

Before you mark this PR as ready for review, make sure that you have:
- [ ] Added solid test coverage and/or screenshots from manual
testing
- [ ] Done a self-review taking into account security and performance
aspects
- [ ] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/edit_prediction/src/zeta.rs    |  12 +
crates/zeta_prompt/src/zeta_prompt.rs | 135 +++++++++++++++++++++-------
2 files changed, 110 insertions(+), 37 deletions(-)

Detailed changes

crates/edit_prediction/src/zeta.rs 🔗

@@ -24,7 +24,7 @@ use zeta_prompt::{ParsedOutput, ZetaPromptInput};
 use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
 use zeta_prompt::{
     CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
-    prompt_input_contains_special_tokens,
+    prompt_input_contains_special_tokens, stop_tokens_for_format,
     zeta1::{self, EDITABLE_REGION_END_MARKER},
 };
 
@@ -192,7 +192,10 @@ pub fn request_prediction_with_zeta(
                                 custom_settings,
                                 prompt,
                                 max_tokens,
-                                vec![],
+                                stop_tokens_for_format(zeta_version)
+                                    .iter()
+                                    .map(|token| token.to_string())
+                                    .collect(),
                                 open_ai_compatible_api_key.clone(),
                                 &http_client,
                             )
@@ -226,7 +229,10 @@ pub fn request_prediction_with_zeta(
                         model: config.model_id.clone().unwrap_or_default(),
                         prompt,
                         temperature: None,
-                        stop: vec![],
+                        stop: stop_tokens_for_format(config.format)
+                            .iter()
+                            .map(|token| std::borrow::Cow::Borrowed(*token))
+                            .collect(),
                         max_tokens: Some(2048),
                         environment,
                     };

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -222,6 +222,21 @@ pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str]
     }
 }
 
+pub fn stop_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] {
+    match format {
+        ZetaFormat::v0226Hashline => &[hashline::NO_EDITS_COMMAND_MARKER],
+        ZetaFormat::V0112MiddleAtEnd
+        | ZetaFormat::V0113Ordered
+        | ZetaFormat::V0114180EditableRegion
+        | ZetaFormat::V0120GitMergeMarkers
+        | ZetaFormat::V0131GitMergeMarkersPrefix
+        | ZetaFormat::V0211Prefill
+        | ZetaFormat::V0211SeedCoder
+        | ZetaFormat::V0304VariableEdit
+        | ZetaFormat::V0304SeedNoEdits => &[],
+    }
+}
+
 pub fn excerpt_ranges_for_format(
     format: ZetaFormat,
     ranges: &ExcerptRanges,
@@ -1010,12 +1025,14 @@ pub mod hashline {
 
     const SET_COMMAND_MARKER: &str = "<|set|>";
     const INSERT_COMMAND_MARKER: &str = "<|insert|>";
+    pub const NO_EDITS_COMMAND_MARKER: &str = "<|no_edits|>";
 
     pub fn special_tokens() -> &'static [&'static str] {
         return &[
             SET_COMMAND_MARKER,
             "<|set_range|>",
             INSERT_COMMAND_MARKER,
+            NO_EDITS_COMMAND_MARKER,
             CURSOR_MARKER,
             "<|file_sep|>",
             "<|fim_prefix|>",
@@ -1109,6 +1126,7 @@ pub mod hashline {
         }
 
         prompt.push_str(END_MARKER);
+        prompt.push('\n');
     }
 
     /// A single edit command parsed from the model output.
@@ -1234,7 +1252,9 @@ pub mod hashline {
     }
 
     pub fn output_has_edit_commands(model_output: &str) -> bool {
-        model_output.contains(SET_COMMAND_MARKER) || model_output.contains(INSERT_COMMAND_MARKER)
+        model_output.contains(SET_COMMAND_MARKER)
+            || model_output.contains(INSERT_COMMAND_MARKER)
+            || model_output.contains(NO_EDITS_COMMAND_MARKER)
     }
 
     /// Apply `<|set|>` and `<|insert|>` edit commands from the model output to the
@@ -1245,6 +1265,13 @@ pub mod hashline {
     ///
     /// Returns the full replacement text for the editable region.
     pub fn apply_edit_commands(editable_region: &str, model_output: &str) -> String {
+        if model_output
+            .trim_start()
+            .starts_with(NO_EDITS_COMMAND_MARKER)
+        {
+            return editable_region.to_string();
+        }
+
         let original_lines: Vec<&str> = editable_region.lines().collect();
         let old_hashes: Vec<u8> = original_lines
             .iter()
@@ -1549,6 +1576,10 @@ pub mod hashline {
             result.pop();
         }
 
+        if result.is_empty() {
+            return Ok(NO_EDITS_COMMAND_MARKER.to_string());
+        }
+
         Ok(result)
     }
 
@@ -1579,7 +1610,8 @@ pub mod hashline {
                     <|fim_middle|>current
                     0:5c|hello<|user_cursor|> world
                     <|fim_suffix|>
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "multiline_cursor_on_second_line",
@@ -1594,7 +1626,8 @@ pub mod hashline {
                     1:26|b<|user_cursor|>bb
                     2:29|ccc
                     <|fim_suffix|>
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "no_trailing_newline_in_context",
@@ -1608,7 +1641,8 @@ pub mod hashline {
                     0:d9|lin<|user_cursor|>e1
                     1:da|line2
                     <|fim_suffix|>
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "leading_newline_in_editable_region",
@@ -1622,7 +1656,8 @@ pub mod hashline {
                     0:00|
                     1:26|a<|user_cursor|>bc
                     <|fim_suffix|>
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "with_suffix",
@@ -1636,7 +1671,8 @@ pub mod hashline {
                     0:26|ab<|user_cursor|>c
                     <|fim_suffix|>
                     def
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "unicode_two_byte_chars",
@@ -1649,7 +1685,8 @@ pub mod hashline {
                     <|fim_middle|>current
                     0:1b|hé<|user_cursor|>llo
                     <|fim_suffix|>
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "unicode_three_byte_chars",
@@ -1662,7 +1699,8 @@ pub mod hashline {
                     <|fim_middle|>current
                     0:80|日本<|user_cursor|>語
                     <|fim_suffix|>
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "unicode_four_byte_chars",
@@ -1675,7 +1713,8 @@ pub mod hashline {
                     <|fim_middle|>current
                     0:6b|a🌍<|user_cursor|>b
                     <|fim_suffix|>
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "cursor_at_start_of_region_not_placed",
@@ -1688,7 +1727,8 @@ pub mod hashline {
                     <|fim_middle|>current
                     0:26|abc
                     <|fim_suffix|>
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "cursor_at_end_of_line_not_placed",
@@ -1702,7 +1742,8 @@ pub mod hashline {
                     0:26|abc
                     1:2f|def
                     <|fim_suffix|>
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
                 Case {
                     name: "cursor_offset_relative_to_context_not_editable_region",
@@ -1721,7 +1762,8 @@ pub mod hashline {
                     1:26|b<|user_cursor|>bb
                     <|fim_suffix|>
                     suf
-                    <|fim_middle|>updated"},
+                    <|fim_middle|>updated
+                    "},
                 },
             ];
 
@@ -1894,6 +1936,18 @@ pub mod hashline {
                     world
                 "},
                 },
+                Case {
+                    name: "no_edits_command_returns_original",
+                    original: indoc! {"
+                    hello
+                    world
+                "},
+                    model_output: "<|no_edits|>",
+                    expected: indoc! {"
+                    hello
+                    world
+                "},
+                },
                 Case {
                     name: "wrong_hash_set_ignored",
                     original: indoc! {"
@@ -2113,6 +2167,7 @@ pub mod hashline {
             )));
             assert!(!hashline::output_has_edit_commands("just plain text"));
             assert!(!hashline::output_has_edit_commands("NO_EDITS"));
+            assert!(hashline::output_has_edit_commands("<|no_edits|>"));
         }
 
         // ---- hashline::patch_to_edit_commands round-trip tests ----
@@ -2350,35 +2405,47 @@ pub mod hashline {
                     }
                 "#},
                     patch: indoc! {r#"
-                    @@ -1,3 +1,3 @@
-                     fn main() {
-                    -    println!();
-                    +    eprintln!("");
-                     }
-                "#},
+                        @@ -1,3 +1,3 @@
+                        fn main() {
+                        -    println!();
+                        +    eprintln!("");
+                        }
+                    "#},
                     expected_new: indoc! {r#"
-                    fn main() {
-                        eprintln!("<|user_cursor|>");
-                    }
-                "#},
+                        fn main() {
+                            eprintln!("<|user_cursor|>");
+                        }
+                    "#},
                 },
                 Case {
                     name: "non_local_hunk_header_pure_insertion_repro",
                     old: indoc! {"
-                    aaa
-                    bbb
-                "},
+                        aaa
+                        bbb
+                    "},
                     patch: indoc! {"
-                    @@ -20,2 +20,3 @@
-                     aaa
-                    +xxx
-                     bbb
-                "},
+                        @@ -20,2 +20,3 @@
+                        aaa
+                        +xxx
+                        bbb
+                    "},
                     expected_new: indoc! {"
-                    aaa
-                    xxx
-                    bbb
-                "},
+                        aaa
+                        xxx
+                        bbb
+                    "},
+                },
+                Case {
+                    name: "empty_patch_produces_no_edits_marker",
+                    old: indoc! {"
+                        aaa
+                        bbb
+                    "},
+                    patch: "@@ -20,2 +20,3 @@\n",
+                    expected_new: indoc! {"
+                        aaa
+                        bbb
+                    "},
                 },
             ];