diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index 17f379f23eeac36f388dbcf72e00f4c63ed7a053..9a8b9767ceda0c311ce0779fe1c0ac948b9485ce 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -19,7 +19,13 @@ use zeta_prompt::format_zeta_prompt; use zeta_prompt::{CURSOR_MARKER, ZetaVersion}; pub const MAX_CONTEXT_TOKENS: usize = 350; -pub const MAX_EDITABLE_TOKENS: usize = 150; + +pub fn max_editable_tokens(version: ZetaVersion) -> usize { + match version { + ZetaVersion::V0112_MiddleAtEnd | ZetaVersion::V0113_Ordered => 150, + ZetaVersion::V0114_180EditableRegion => 180, + } +} pub fn request_prediction_with_zeta2( store: &mut EditPredictionStore, @@ -61,6 +67,7 @@ pub fn request_prediction_with_zeta2( events, excerpt_path, cursor_offset, + zeta_version, ); let prompt = format_zeta_prompt(&prompt_input, zeta_version); @@ -202,6 +209,7 @@ pub fn zeta2_prompt_input( events: Vec<Arc<Event>>, excerpt_path: Arc<Path>, cursor_offset: usize, + zeta_version: ZetaVersion, ) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) { let cursor_point = cursor_offset.to_point(snapshot); @@ -209,7 +217,7 @@ pub fn zeta2_prompt_input( crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position( cursor_point, snapshot, - MAX_EDITABLE_TOKENS, + max_editable_tokens(zeta_version), MAX_CONTEXT_TOKENS, ); diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs index d876c24726b783102f166049ae0f07e6e7c78d81..3103735ef1a4bbe2328a6ce420750ca54e775787 100644 --- a/crates/edit_prediction_cli/src/format_prompt.rs +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -45,18 +45,19 @@ pub async fn run_format_prompt( let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column); let snapshot = cx.background_spawn(snapshot_fut).await; - let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( - cursor_point, - &snapshot, 
- edit_prediction::zeta2::MAX_EDITABLE_TOKENS, - edit_prediction::zeta2::MAX_CONTEXT_TOKENS, - ); - let editable_range = editable_range.to_offset(&snapshot); - let context_range = context_range.to_offset(&snapshot); - match args.provider { - PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => { + PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => { step_progress.set_substatus("formatting teacher prompt"); + + let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( + cursor_point, + &snapshot, + edit_prediction::zeta2::max_editable_tokens(version), + edit_prediction::zeta2::MAX_CONTEXT_TOKENS, + ); + let editable_range = editable_range.to_offset(&snapshot); + let context_range = context_range.to_offset(&snapshot); + let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range); example.prompt = Some(ExamplePrompt { input: prompt, @@ -72,6 +73,15 @@ pub async fn run_format_prompt( PredictionProvider::Zeta2(version) => { step_progress.set_substatus("formatting zeta2 prompt"); + let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( + cursor_point, + &snapshot, + edit_prediction::zeta2::max_editable_tokens(version), + edit_prediction::zeta2::MAX_CONTEXT_TOKENS, + ); + let editable_range = editable_range.to_offset(&snapshot); + let context_range = context_range.to_offset(&snapshot); + let context_start = context_range.start; let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start; let editable_range_in_excerpt = diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index b8954f92745992839234cab73142278a176bc955..13a8399b3fe9853aaaa305355f3329170cf399fd 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -195,8 +195,8 @@ enum PredictionProvider { Mercury, Zeta1, Zeta2(ZetaVersion), - Teacher, - TeacherNonBatching, + 
Teacher(ZetaVersion), + TeacherNonBatching(ZetaVersion), } impl Default for PredictionProvider { @@ -212,8 +212,10 @@ impl std::fmt::Display for PredictionProvider { PredictionProvider::Mercury => write!(f, "mercury"), PredictionProvider::Zeta1 => write!(f, "zeta1"), PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"), - PredictionProvider::Teacher => write!(f, "teacher"), - PredictionProvider::TeacherNonBatching => write!(f, "teacher-non-batching"), + PredictionProvider::Teacher(version) => write!(f, "teacher:{version}"), + PredictionProvider::TeacherNonBatching(version) => { + write!(f, "teacher-non-batching:{version}") + } } } } @@ -221,29 +223,31 @@ impl std::fmt::Display for PredictionProvider { impl std::str::FromStr for PredictionProvider { type Err = anyhow::Error; - fn from_str(s: &str) -> Result<Self, Self::Err> { + fn from_str(mut s: &str) -> Result<Self, Self::Err> { + let mut version = ZetaVersion::default(); + if let Some((first, second)) = s.split_once(':') { + version = ZetaVersion::parse(second)?; + s = first; + } + let s_lower = s.to_lowercase(); match s_lower.as_str() { "sweep" => Ok(PredictionProvider::Sweep), "mercury" => Ok(PredictionProvider::Mercury), "zeta1" => Ok(PredictionProvider::Zeta1), - // Handle both old format "zeta2" and new format with version - "zeta2" => Ok(PredictionProvider::Zeta2(ZetaVersion::default())), - "teacher" => Ok(PredictionProvider::Teacher), + "zeta2" => Ok(PredictionProvider::Zeta2(version)), + "teacher" => Ok(PredictionProvider::Teacher(version)), "teacher-non-batching" | "teacher_non_batching" | "teachernonbatching" => { - Ok(PredictionProvider::TeacherNonBatching) + Ok(PredictionProvider::TeacherNonBatching(version)) } - _ if s_lower.starts_with("zeta2:") => { - let version_str = &s[6..]; - let version = ZetaVersion::parse(version_str)?; - Ok(PredictionProvider::Zeta2(version)) - } - _ => anyhow::bail!( - "unknown provider `{s}`. 
Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\ + _ => { + anyhow::bail!( + "unknown provider `{s}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher-non-batching\n\ For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\ Available zeta versions:\n{}", - ZetaVersion::options_as_string() - ), + ZetaVersion::options_as_string() + ) + } } } } diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index be22da635b320407befc46f04748f18c874294ac..17ff5347561c429a3d66987ce27a9f62c2506cae 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -20,6 +20,7 @@ use std::{ atomic::{AtomicUsize, Ordering::SeqCst}, }, }, +use zeta_prompt::ZetaVersion; static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new(); @@ -42,10 +43,9 @@ pub async fn run_prediction( run_context_retrieval(example, app_state.clone(), cx.clone()).await?; - if matches!( - provider, - PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching - ) { + if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) = + args.provider + { let _step_progress = Progress::global().start(Step::Predict, &example.spec.name); run_format_prompt( @@ -56,8 +56,8 @@ pub async fn run_prediction( ) .await?; - let batched = matches!(provider, PredictionProvider::Teacher); - return predict_anthropic(example, repetition_count, batched).await; + let batched = matches!(provider, PredictionProvider::Teacher(..)); + return predict_anthropic(example, repetition_count, version, batched).await; } run_load_project(example, app_state.clone(), cx.clone()).await?; @@ -96,7 +96,7 @@ pub async fn run_prediction( } PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, - PredictionProvider::Teacher | 
PredictionProvider::TeacherNonBatching => { + PredictionProvider::Teacher(..) | PredictionProvider::TeacherNonBatching(..) => { unreachable!() } }; @@ -246,6 +246,7 @@ pub async fn run_prediction( async fn predict_anthropic( example: &mut Example, _repetition_count: usize, + version: ZetaVersion, batched: bool, ) -> anyhow::Result<()> { let llm_model_name = "claude-sonnet-4-5"; @@ -292,7 +293,11 @@ async fn predict_anthropic( let prediction = ExamplePrediction { actual_patch, actual_output, - provider: PredictionProvider::Teacher, + provider: if batched { + PredictionProvider::Teacher(version) + } else { + PredictionProvider::TeacherNonBatching(version) + }, }; example.predictions.push(prediction); @@ -301,7 +306,7 @@ async fn predict_anthropic( pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> { match provider { - PredictionProvider::Teacher => { + PredictionProvider::Teacher(..) => { let llm_client = ANTHROPIC_CLIENT.get_or_init(|| { AnthropicClient::batch(&crate::paths::LLM_CACHE_DB) .expect("Failed to create Anthropic client") diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs index 35fcb0d02453567b60f0ed95292b8c418a0da40a..9fa672d85a814d5d089f0e2147d72c69af6a1da1 100644 --- a/crates/zeta_prompt/src/zeta_prompt.rs +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -34,8 +34,9 @@ pub struct ZetaPromptInput { #[allow(non_camel_case_types)] pub enum ZetaVersion { V0112_MiddleAtEnd, - #[default] V0113_Ordered, + #[default] + V0114_180EditableRegion, } impl std::fmt::Display for ZetaVersion { @@ -72,10 +73,6 @@ impl ZetaVersion { .collect::<Vec<_>>() .concat() } - - pub fn default_as_string() -> String { - <&'static str>::from(Self::default()).to_string() - } } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -140,7 +137,7 @@ pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> Stri ZetaVersion::V0112_MiddleAtEnd => { v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, 
input); } - ZetaVersion::V0113_Ordered => { + ZetaVersion::V0113_Ordered | ZetaVersion::V0114_180EditableRegion => { v0113_ordered::write_cursor_excerpt_section(&mut prompt, input) } }