Detailed changes
@@ -19,7 +19,13 @@ use zeta_prompt::format_zeta_prompt;
use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
pub const MAX_CONTEXT_TOKENS: usize = 350;
-pub const MAX_EDITABLE_TOKENS: usize = 150;
+
+pub fn max_editable_tokens(version: ZetaVersion) -> usize {
+ match version {
+ ZetaVersion::V0112_MiddleAtEnd | ZetaVersion::V0113_Ordered => 150,
+ ZetaVersion::V0114_180EditableRegion => 180,
+ }
+}
pub fn request_prediction_with_zeta2(
store: &mut EditPredictionStore,
@@ -61,6 +67,7 @@ pub fn request_prediction_with_zeta2(
events,
excerpt_path,
cursor_offset,
+ zeta_version,
);
let prompt = format_zeta_prompt(&prompt_input, zeta_version);
@@ -202,6 +209,7 @@ pub fn zeta2_prompt_input(
events: Vec<Arc<zeta_prompt::Event>>,
excerpt_path: Arc<Path>,
cursor_offset: usize,
+ zeta_version: ZetaVersion,
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
let cursor_point = cursor_offset.to_point(snapshot);
@@ -209,7 +217,7 @@ pub fn zeta2_prompt_input(
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
cursor_point,
snapshot,
- MAX_EDITABLE_TOKENS,
+ max_editable_tokens(zeta_version),
MAX_CONTEXT_TOKENS,
);
@@ -45,18 +45,19 @@ pub async fn run_format_prompt(
let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
let snapshot = cx.background_spawn(snapshot_fut).await;
- let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
- cursor_point,
- &snapshot,
- edit_prediction::zeta2::MAX_EDITABLE_TOKENS,
- edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
- );
- let editable_range = editable_range.to_offset(&snapshot);
- let context_range = context_range.to_offset(&snapshot);
-
match args.provider {
- PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
+ PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
step_progress.set_substatus("formatting teacher prompt");
+
+ let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
+ cursor_point,
+ &snapshot,
+ edit_prediction::zeta2::max_editable_tokens(version),
+ edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
+ );
+ let editable_range = editable_range.to_offset(&snapshot);
+ let context_range = context_range.to_offset(&snapshot);
+
let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
example.prompt = Some(ExamplePrompt {
input: prompt,
@@ -72,6 +73,15 @@ pub async fn run_format_prompt(
PredictionProvider::Zeta2(version) => {
step_progress.set_substatus("formatting zeta2 prompt");
+ let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
+ cursor_point,
+ &snapshot,
+ edit_prediction::zeta2::max_editable_tokens(version),
+ edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
+ );
+ let editable_range = editable_range.to_offset(&snapshot);
+ let context_range = context_range.to_offset(&snapshot);
+
let context_start = context_range.start;
let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
let editable_range_in_excerpt =
@@ -195,8 +195,8 @@ enum PredictionProvider {
Mercury,
Zeta1,
Zeta2(ZetaVersion),
- Teacher,
- TeacherNonBatching,
+ Teacher(ZetaVersion),
+ TeacherNonBatching(ZetaVersion),
}
impl Default for PredictionProvider {
@@ -212,8 +212,10 @@ impl std::fmt::Display for PredictionProvider {
PredictionProvider::Mercury => write!(f, "mercury"),
PredictionProvider::Zeta1 => write!(f, "zeta1"),
PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
- PredictionProvider::Teacher => write!(f, "teacher"),
- PredictionProvider::TeacherNonBatching => write!(f, "teacher-non-batching"),
+ PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"),
+ PredictionProvider::TeacherNonBatching(version) => {
+ write!(f, "teacher-non-batching:{version}")
+ }
}
}
}
@@ -221,29 +223,31 @@ impl std::fmt::Display for PredictionProvider {
impl std::str::FromStr for PredictionProvider {
type Err = anyhow::Error;
- fn from_str(s: &str) -> Result<Self, Self::Err> {
+ fn from_str(mut s: &str) -> Result<Self, Self::Err> {
+ let mut version = ZetaVersion::default();
+ if let Some((first, second)) = s.split_once(':') {
+ version = ZetaVersion::parse(second)?;
+ s = first;
+ }
+
let s_lower = s.to_lowercase();
match s_lower.as_str() {
"sweep" => Ok(PredictionProvider::Sweep),
"mercury" => Ok(PredictionProvider::Mercury),
"zeta1" => Ok(PredictionProvider::Zeta1),
- // Handle both old format "zeta2" and new format with version
- "zeta2" => Ok(PredictionProvider::Zeta2(ZetaVersion::default())),
- "teacher" => Ok(PredictionProvider::Teacher),
+ "zeta2" => Ok(PredictionProvider::Zeta2(version)),
+ "teacher" => Ok(PredictionProvider::Teacher(version)),
"teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => {
- Ok(PredictionProvider::TeacherNonBatching)
+ Ok(PredictionProvider::TeacherNonBatching(version))
}
- _ if s_lower.starts_with("zeta2:") => {
- let version_str = &s[6..];
- let version = ZetaVersion::parse(version_str)?;
- Ok(PredictionProvider::Zeta2(version))
- }
- _ => anyhow::bail!(
- "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
+ _ => {
+ anyhow::bail!(
+ "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\
For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
Available zeta versions:\n{}",
- ZetaVersion::options_as_string()
- ),
+ ZetaVersion::options_as_string()
+ )
+ }
}
}
}
@@ -20,6 +20,7 @@ use std::{
atomic::{AtomicUsize, Ordering::SeqCst},
},
};
+use zeta_prompt::ZetaVersion;
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
@@ -42,10 +43,9 @@ pub async fn run_prediction(
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
- if matches!(
- provider,
- PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
- ) {
+ if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
+ args.provider
+ {
let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);
run_format_prompt(
@@ -56,8 +56,8 @@ pub async fn run_prediction(
)
.await?;
- let batched = matches!(provider, PredictionProvider::Teacher);
- return predict_anthropic(example, repetition_count, batched).await;
+ let batched = matches!(provider, PredictionProvider::Teacher(..));
+ return predict_anthropic(example, repetition_count, version, batched).await;
}
run_load_project(example, app_state.clone(), cx.clone()).await?;
@@ -96,7 +96,7 @@ pub async fn run_prediction(
}
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
- PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
+ PredictionProvider::Teacher(..) | PredictionProvider::TeacherNonBatching(..) => {
unreachable!()
}
};
@@ -246,6 +246,7 @@ pub async fn run_prediction(
async fn predict_anthropic(
example: &mut Example,
_repetition_count: usize,
+ version: ZetaVersion,
batched: bool,
) -> anyhow::Result<()> {
let llm_model_name = "claude-sonnet-4-5";
@@ -292,7 +293,11 @@ async fn predict_anthropic(
let prediction = ExamplePrediction {
actual_patch,
actual_output,
- provider: PredictionProvider::Teacher,
+ provider: if batched {
+ PredictionProvider::Teacher(version)
+ } else {
+ PredictionProvider::TeacherNonBatching(version)
+ },
};
example.predictions.push(prediction);
@@ -301,7 +306,7 @@ async fn predict_anthropic(
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
match provider {
- PredictionProvider::Teacher => {
+ PredictionProvider::Teacher(..) => {
let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
.expect("Failed to create Anthropic client")
@@ -34,8 +34,9 @@ pub struct ZetaPromptInput {
#[allow(non_camel_case_types)]
pub enum ZetaVersion {
V0112_MiddleAtEnd,
- #[default]
V0113_Ordered,
+ #[default]
+ V0114_180EditableRegion,
}
impl std::fmt::Display for ZetaVersion {
@@ -72,10 +73,6 @@ impl ZetaVersion {
.collect::<Vec<_>>()
.concat()
}
-
- pub fn default_as_string() -> String {
- <&'static str>::from(Self::default()).to_string()
- }
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -140,7 +137,7 @@ pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> Stri
ZetaVersion::V0112_MiddleAtEnd => {
v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
}
- ZetaVersion::V0113_Ordered => {
+ ZetaVersion::V0113_Ordered | ZetaVersion::V0114_180EditableRegion => {
v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
}
}