@@ -24,7 +24,7 @@ use zeta_prompt::{ParsedOutput, ZetaPromptInput};
use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::{
CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
- prompt_input_contains_special_tokens,
+ prompt_input_contains_special_tokens, stop_tokens_for_format,
zeta1::{self, EDITABLE_REGION_END_MARKER},
};
@@ -192,7 +192,10 @@ pub fn request_prediction_with_zeta(
custom_settings,
prompt,
max_tokens,
- vec![],
+ stop_tokens_for_format(zeta_version)
+ .iter()
+ .map(|token| token.to_string())
+ .collect(),
open_ai_compatible_api_key.clone(),
&http_client,
)
@@ -226,7 +229,10 @@ pub fn request_prediction_with_zeta(
model: config.model_id.clone().unwrap_or_default(),
prompt,
temperature: None,
- stop: vec![],
+ stop: stop_tokens_for_format(config.format)
+ .iter()
+ .map(|token| std::borrow::Cow::Borrowed(*token))
+ .collect(),
max_tokens: Some(2048),
environment,
};
@@ -222,6 +222,21 @@ pub fn special_tokens_for_format(format: ZetaFormat) -> &'static [&'static str]
}
}
+pub fn stop_tokens_for_format(format: ZetaFormat) -> &'static [&'static str] {
+ match format {
+ ZetaFormat::v0226Hashline => &[hashline::NO_EDITS_COMMAND_MARKER],
+ ZetaFormat::V0112MiddleAtEnd
+ | ZetaFormat::V0113Ordered
+ | ZetaFormat::V0114180EditableRegion
+ | ZetaFormat::V0120GitMergeMarkers
+ | ZetaFormat::V0131GitMergeMarkersPrefix
+ | ZetaFormat::V0211Prefill
+ | ZetaFormat::V0211SeedCoder
+ | ZetaFormat::V0304VariableEdit
+ | ZetaFormat::V0304SeedNoEdits => &[],
+ }
+}
+
pub fn excerpt_ranges_for_format(
format: ZetaFormat,
ranges: &ExcerptRanges,
@@ -1010,12 +1025,14 @@ pub mod hashline {
const SET_COMMAND_MARKER: &str = "<|set|>";
const INSERT_COMMAND_MARKER: &str = "<|insert|>";
+ pub const NO_EDITS_COMMAND_MARKER: &str = "<|no_edits|>";
pub fn special_tokens() -> &'static [&'static str] {
return &[
SET_COMMAND_MARKER,
"<|set_range|>",
INSERT_COMMAND_MARKER,
+ NO_EDITS_COMMAND_MARKER,
CURSOR_MARKER,
"<|file_sep|>",
"<|fim_prefix|>",
@@ -1109,6 +1126,7 @@ pub mod hashline {
}
prompt.push_str(END_MARKER);
+ prompt.push('\n');
}
/// A single edit command parsed from the model output.
@@ -1234,7 +1252,9 @@ pub mod hashline {
}
pub fn output_has_edit_commands(model_output: &str) -> bool {
- model_output.contains(SET_COMMAND_MARKER) || model_output.contains(INSERT_COMMAND_MARKER)
+ model_output.contains(SET_COMMAND_MARKER)
+ || model_output.contains(INSERT_COMMAND_MARKER)
+ || model_output.contains(NO_EDITS_COMMAND_MARKER)
}
/// Apply `<|set|>` and `<|insert|>` edit commands from the model output to the
@@ -1245,6 +1265,13 @@ pub mod hashline {
///
/// Returns the full replacement text for the editable region.
pub fn apply_edit_commands(editable_region: &str, model_output: &str) -> String {
+ if model_output
+ .trim_start()
+ .starts_with(NO_EDITS_COMMAND_MARKER)
+ {
+ return editable_region.to_string();
+ }
+
let original_lines: Vec<&str> = editable_region.lines().collect();
let old_hashes: Vec<u8> = original_lines
.iter()
@@ -1549,6 +1576,10 @@ pub mod hashline {
result.pop();
}
+ if result.is_empty() {
+ return Ok(NO_EDITS_COMMAND_MARKER.to_string());
+ }
+
Ok(result)
}
@@ -1579,7 +1610,8 @@ pub mod hashline {
<|fim_middle|>current
0:5c|hello<|user_cursor|> world
<|fim_suffix|>
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "multiline_cursor_on_second_line",
@@ -1594,7 +1626,8 @@ pub mod hashline {
1:26|b<|user_cursor|>bb
2:29|ccc
<|fim_suffix|>
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "no_trailing_newline_in_context",
@@ -1608,7 +1641,8 @@ pub mod hashline {
0:d9|lin<|user_cursor|>e1
1:da|line2
<|fim_suffix|>
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "leading_newline_in_editable_region",
@@ -1622,7 +1656,8 @@ pub mod hashline {
0:00|
1:26|a<|user_cursor|>bc
<|fim_suffix|>
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "with_suffix",
@@ -1636,7 +1671,8 @@ pub mod hashline {
0:26|ab<|user_cursor|>c
<|fim_suffix|>
def
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "unicode_two_byte_chars",
@@ -1649,7 +1685,8 @@ pub mod hashline {
<|fim_middle|>current
0:1b|hé<|user_cursor|>llo
<|fim_suffix|>
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "unicode_three_byte_chars",
@@ -1662,7 +1699,8 @@ pub mod hashline {
<|fim_middle|>current
0:80|日本<|user_cursor|>語
<|fim_suffix|>
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "unicode_four_byte_chars",
@@ -1675,7 +1713,8 @@ pub mod hashline {
<|fim_middle|>current
0:6b|a🌍<|user_cursor|>b
<|fim_suffix|>
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "cursor_at_start_of_region_not_placed",
@@ -1688,7 +1727,8 @@ pub mod hashline {
<|fim_middle|>current
0:26|abc
<|fim_suffix|>
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "cursor_at_end_of_line_not_placed",
@@ -1702,7 +1742,8 @@ pub mod hashline {
0:26|abc
1:2f|def
<|fim_suffix|>
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
Case {
name: "cursor_offset_relative_to_context_not_editable_region",
@@ -1721,7 +1762,8 @@ pub mod hashline {
1:26|b<|user_cursor|>bb
<|fim_suffix|>
suf
- <|fim_middle|>updated"},
+ <|fim_middle|>updated
+ "},
},
];
@@ -1894,6 +1936,18 @@ pub mod hashline {
world
"},
},
+ Case {
+ name: "no_edits_command_returns_original",
+ original: indoc! {"
+ hello
+ world
+ "},
+ model_output: "<|no_edits|>",
+ expected: indoc! {"
+ hello
+ world
+ "},
+ },
Case {
name: "wrong_hash_set_ignored",
original: indoc! {"
@@ -2113,6 +2167,7 @@ pub mod hashline {
)));
assert!(!hashline::output_has_edit_commands("just plain text"));
assert!(!hashline::output_has_edit_commands("NO_EDITS"));
+ assert!(hashline::output_has_edit_commands("<|no_edits|>"));
}
// ---- hashline::patch_to_edit_commands round-trip tests ----
@@ -2350,35 +2405,47 @@ pub mod hashline {
}
"#},
patch: indoc! {r#"
- @@ -1,3 +1,3 @@
- fn main() {
- - println!();
- + eprintln!("");
- }
- "#},
+ @@ -1,3 +1,3 @@
+ fn main() {
+ - println!();
+ + eprintln!("");
+ }
+ "#},
expected_new: indoc! {r#"
- fn main() {
- eprintln!("<|user_cursor|>");
- }
- "#},
+ fn main() {
+ eprintln!("<|user_cursor|>");
+ }
+ "#},
},
Case {
name: "non_local_hunk_header_pure_insertion_repro",
old: indoc! {"
- aaa
- bbb
- "},
+ aaa
+ bbb
+ "},
patch: indoc! {"
- @@ -20,2 +20,3 @@
- aaa
- +xxx
- bbb
- "},
+ @@ -20,2 +20,3 @@
+ aaa
+ +xxx
+ bbb
+ "},
expected_new: indoc! {"
- aaa
- xxx
- bbb
- "},
+ aaa
+ xxx
+ bbb
+ "},
+ },
+ Case {
+ name: "empty_patch_produces_no_edits_marker",
+ old: indoc! {"
+ aaa
+ bbb
+ "},
+ patch: "@@ -20,2 +20,3 @@\n",
+ expected_new: indoc! {"
+ aaa
+ bbb
+ "},
},
];